diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index c0add220..8197d3a7 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,20 +2,17 @@ name: Go on: [push, pull_request] -env: - GOTOOLCHAIN: local - jobs: lint: runs-on: ubuntu-latest name: Lint (latest) steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: - go-version: "1.26" + go-version: "1.23" cache: true - name: Install libolm @@ -24,25 +21,27 @@ jobs: - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest - go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - - name: Run pre-commit - uses: pre-commit/action@v3.0.1 + - name: Install pre-commit + run: pip install pre-commit + + - name: Lint + run: pre-commit run -a build: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - go-version: ["1.25", "1.26"] - name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm) + go-version: ["1.22", "1.23"] + name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, libolm) steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: true @@ -61,29 +60,30 @@ jobs: - name: Test run: go test -json -v ./... 2>&1 | gotestfmt - - name: Test (jsonv2) - env: - GOEXPERIMENT: jsonv2 - run: go test -json -v ./... 2>&1 | gotestfmt - build-goolm: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - go-version: ["1.25", "1.26"] - name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm) + go-version: ["1.22", "1.23"] + name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, goolm) steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: true + - name: Set up gotestfmt + uses: GoTestTools/gotestfmt-action@v2 + with: + token: ${{ secrets.GITHUB_TOKEN }} + - name: Build - run: | - rm -rf crypto/libolm - go build -tags=goolm -v ./... + run: go build -tags=goolm -v ./... + + - name: Test + run: go test -tags=goolm -json -v ./... 2>&1 | gotestfmt diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml deleted file mode 100644 index 9a9e7375..00000000 --- a/.github/workflows/stale.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: 'Lock old issues' - -on: - schedule: - - cron: '0 6 * * *' - workflow_dispatch: - -permissions: - issues: write -# pull-requests: write -# discussions: write - -concurrency: - group: lock-threads - -jobs: - lock-stale: - runs-on: ubuntu-latest - steps: - - uses: dessant/lock-threads@v6 - id: lock - with: - issue-inactive-days: 90 - process-only: issues - - name: Log processed threads - run: | - if [ '${{ steps.lock.outputs.issues }}' ]; then - echo "Issues:" && echo '${{ steps.lock.outputs.issues }}' | jq -r '.[] | "https://github.com/\(.owner)/\(.repo)/issues/\(.issue_number)"' - fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 616fccb2..1ef1b112 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v6.0.0 + rev: v4.6.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -9,7 +9,7 @@ repos: - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang - rev: v1.0.0-rc.4 + rev: v1.0.0-rc.1 hooks: - id: go-imports-repo args: @@ -18,12 +18,8 @@ repos: - "-w" - id: go-vet-repo-mod - id: go-mod-tidy - - id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go - rev: v0.4.2 + rev: v0.3.1 hooks: - id: prevent-literal-http-methods - - id: zerolog-ban-global-log - - id: zerolog-ban-msgf - - id: zerolog-use-stringer diff --git a/CHANGELOG.md b/CHANGELOG.md index f2829199..819f69d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,500 +1,3 @@ -## v0.26.3 (2026-02-16) - -* Bumped minimum Go version to 1.25. -* *(client)* Added fields for sending [MSC4354] sticky events. -* *(bridgev2)* Added automatic message request accepting when sending message. -* *(mediaproxy)* Added support for federation thumbnail endpoint. -* *(crypto/ssss)* Improved support for recovery keys with slightly broken - metadata. -* *(crypto)* Changed key import to call session received callback even for - sessions that already exist in the database. -* *(appservice)* Fixed building websocket URL accidentally using file path - separators instead of always `/`. -* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field. -* *(client)* Fixed incorrect context usage in async uploads. -* *(crypto)* Fixed panic when passing invalid input to megolm message index - parser used for debugging. -* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned - up properly. - -[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354 - -## v0.26.2 (2026-01-16) - -* *(bridgev2)* Added chunked portal deletion to avoid database locks when - deleting large portals. -* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata - as per [MSC4392]. -* *(bridgev2/login)* Added `default_value` for user input fields. -* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested - HTTP client settings and to reset active connections of the network connector. -* *(bridgev2)* Added interface to let network connectors get the provisioning - API HTTP router and add new endpoints. -* *(event)* Added blurhash field to Beeper link preview objects. -* *(event)* Added [MSC4391] support for bot commands. -* *(event)* Dropped [MSC4332] support for bot commands. -* *(client)* Changed media download methods to return an error if the provided - MXC URI is empty. -* *(client)* Stabilized support for [MSC4323]. -* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events. -* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with - a portal re-ID call. - -[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391 -[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392 - -## v0.26.1 (2025-12-16) - -* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return - the mime type from the callback rather than in the return get media return - value. The callback can now also redirect the caller to a different file. -* *(federation)* Added join/knock/leave functions - (thanks to [@nexy7574] in [#422]). -* *(federation/eventauth)* Fixed various incorrect checks. -* *(client)* Added backoff for retrying media uploads to external URLs - (with MSC3870). -* *(bridgev2/config)* Added support for overriding config fields using - environment variables. -* *(bridgev2/commands)* Added command to mute chat on remote network. -* *(bridgev2)* Added interface for network connectors to redirect to a different - user ID when handling an invite from Matrix. -* *(bridgev2)* Added interface for signaling message request status of portals. -* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag - is set in chat info. -* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if - bridging the new one is successful. -* *(bridgev2/mxmain)* Improved error message when trying to run bridge with - pre-megabridge database when no database migration exists. -* *(bridgev2)* Improved reliability of database migration when enabling split - portals. -* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats. -* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public - portal rooms. -* *(bridgev2)* Stopped hardcoding room versions in favor of checking - server capabilities to determine appropriate `/createRoom` parameters. - -[#422]: https://github.com/mautrix/go/pull/422 - -## v0.26.0 (2025-11-16) - -* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent` - has been able to do the same for a while now. -* *(client,federation)* Added size limits for responses to make it safer to send - requests to untrusted servers. -* *(client)* Added wrapper for `/admin/whois` client API - (thanks to [@nexy7574] in [#411]). -* *(synapseadmin)* Added `force_purge` option to DeleteRoom - (thanks to [@nexy7574] in [#420]). -* *(statestore)* Added saving join rules for rooms. -* *(bridgev2)* Added optional automatic rollback of room state if bridging the - change to the remote network fails. -* *(bridgev2)* Added management room notices if transient disconnect state - doesn't resolve within 3 minutes. -* *(bridgev2)* Added interface to signal that certain participants couldn't be - invited when creating a group. -* *(bridgev2)* Added `select` type for user input fields in login. -* *(bridgev2)* Added interface to let network connector customize personal - filtering space. -* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to - other bots. -* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever - possible. -* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names, - and encrypted files. -* *(bridgev2/commands)* Added command to resync a single portal. -* *(bridgev2/commands)* Added create group command. -* *(bridgev2/config)* Added option to limit maximum number of logins. -* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room - is public. -* *(bridgev2/disappear)* Changed read receipt handling to only start - disappearing timers for messages up to the read message (note: may not work in - all cases if the read receipt points at an unknown event). -* *(event/reply)* Changed plaintext reply fallback removal to only happen when - an HTML reply fallback is removed successfully. -* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run. -* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages. -* *(federation)* Fixed HTTP method for sending transactions - (thanks to [@nexy7574] in [#426]). -* *(federation)* Fixed response body being closed even when using `DontReadBody` - parameter. -* *(federation)* Fixed validating auth for requests with query params. -* *(federation/eventauth)* Fixed typo causing restricted joins to not work. - -[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 -[#411]: github.com/mautrix/go/pull/411 -[#420]: github.com/mautrix/go/pull/420 -[#426]: github.com/mautrix/go/pull/426 - -## v0.25.2 (2025-10-16) - -* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into - `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old - behavior, but most users likely want the relaxed version, as there are real - users whose user IDs aren't valid under the strict rules. -* *(crypto)* Added helper methods for generating and verifying with recovery - keys. -* *(bridgev2/matrix)* Added config option to automatically generate a recovery - key for the bridge bot and self-sign the bridge's device. -* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode - for encryption with standard servers like Synapse. -* *(bridgev2)* Added optional support for implicit read receipts. -* *(bridgev2)* Added interface for deleting chats on remote network. -* *(bridgev2)* Added local enforcement of media duration and size limits. -* *(bridgev2)* Extended event duration logging to log any event taking too long. -* *(bridgev2)* Improved validation in group creation provisioning API. -* *(event)* Added event type constant for poll end events. -* *(client)* Added wrapper for searching user directory. -* *(client)* Improved support for managing [MSC4140] delayed events. -* *(crypto/helper)* Changed default sync handling to not block on waiting for - decryption keys. On initial sync, keys won't be requested at all by default. -* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1). -* *(bridgev2)* Fixed various bugs with migrating to split portals. -* *(event)* Fixed poll start events having incorrect null `m.relates_to`. -* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling. -* *(federation)* Fixed various bugs in event auth. - -## v0.25.1 (2025-09-16) - -* *(client)* Fixed HTTP method of delete devices API call - (thanks to [@fmseals] in [#393]). -* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints - (thanks to [@nexy7574] in [#407]). -* *(client)* Stabilized support for extensible profiles. -* *(client)* Stabilized support for `state_after` in sync. -* *(client)* Removed deprecated MSC2716 requests. -* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if - the content struct doesn't implement `Relatable`. -* *(crypto)* Changed olm unwedging to ignore newly created sessions if they - haven't been used successfully in either direction. -* *(federation)* Added utilities for generating, parsing, validating and - authorizing PDUs. - * Note: the new PDU code depends on `GOEXPERIMENT=jsonv2` -* *(event)* Added `is_animated` flag from [MSC4230] to file info. -* *(event)* Added types for [MSC4332]: In-room bot commands. -* *(event)* Added missing poll end event type for [MSC3381]. -* *(appservice)* Fixed URLs not being escaped properly when using unix socket - for homeserver connections. -* *(format)* Added more helpers for forming markdown links. -* *(event,bridgev2)* Added support for Beeper's disappearing message state event. -* *(bridgev2)* Redesigned group creation interface and added support in commands - and provisioning API. -* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to - get an old event. The method is best effort only, as some configurations don't - allow fetching old events. -* *(bridgev2)* Added shared logic for provisioning that can be reused by the - API, commands and other sources. -* *(bridgev2)* Fixed mentions and URL previews not being copied over when - caption and media are merged. -* *(bridgev2)* Removed config option to change provisioning API prefix, which - had already broken in the previous release. - -[@fmseals]: https://github.com/fmseals -[#393]: https://github.com/mautrix/go/pull/393 -[#407]: https://github.com/mautrix/go/pull/407 -[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381 -[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230 -[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332 - -## v0.25.0 (2025-08-16) - -* Bumped minimum Go version to 1.24. -* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux - with standard library ServeMux. -* *(client,bridgev2)* Added support for creator power in room v12. -* *(client)* Added option to not set `User-Agent` header for improved Wasm - compatibility. -* *(bridgev2)* Added support for following tombstones. -* *(bridgev2)* Added interface for getting arbitrary state event from Matrix. -* *(bridgev2)* Added batching to disappearing message queue to ensure it doesn't - use too many resources even if there are a large number of messages. -* *(bridgev2/commands)* Added support for canceling QR login with `cancel` - command. -* *(client)* Added option to override HTTP client used for .well-known - resolution. -* *(crypto/backup)* Added method for encrypting key backup session without - private keys. -* *(event->id)* Moved room version type and constants to id package. -* *(bridgev2)* Bots in DM portals will now be added to the functional members - state event to hide them from the room name calculation. -* *(bridgev2)* Changed message delete handling to ignore "delete for me" events - if there are multiple Matrix users in the room. -* *(format/htmlparser)* Changed text processing to collapse multiple spaces into - one when outside `pre`/`code` tags. -* *(format/htmlparser)* Removed link suffix in plaintext output when link text - is only missing protocol part of href. - * e.g. `example.com` will turn into - `example.com` rather than `example.com (https://example.com)` -* *(appservice)* Switched appservice websockets from gorilla/websocket to - coder/websocket. -* *(bridgev2/matrix)* Fixed encryption key sharing not ignoring ghosts properly. -* *(crypto/attachments)* Fixed hash check when decrypting file streams. -* *(crypto)* Removed unnecessary `AlreadyShared` error in `ShareGroupSession`. - The function will now act as if it was successful instead. - -## v0.24.2 (2025-07-16) - -* *(bridgev2)* Added support for return values from portal event handlers. Note - that the return value will always be "queued" unless the event buffer is - disabled. -* *(bridgev2)* Added support for [MSC4144] per-message profile passthrough in - relay mode. -* *(bridgev2)* Added option to auto-reconnect logins after a certain period if - they hit an `UNKNOWN_ERROR` state. -* *(bridgev2)* Added analytics for event handler panics. -* *(bridgev2)* Changed new room creation to hardcode room v11 to avoid v12 rooms - being created before proper support for them can be added. -* *(bridgev2)* Changed queuing events to block instead of dropping events if the - buffer is full. -* *(bridgev2)* Fixed assumption that replies to unknown messages are cross-room. -* *(id)* Fixed server name validation not including ports correctly - (thanks to [@krombel] in [#392]). -* *(federation)* Fixed base64 algorithm in signature generation. -* *(event)* Fixed [MSC4144] fallbacks not being removed from edits. - -[@krombel]: https://github.com/krombel -[#392]: https://github.com/mautrix/go/pull/392 - -## v0.24.1 (2025-06-16) - -* *(commands)* Added framework for using reactions as buttons that execute - command handlers. -* *(client)* Added wrapper for `/relations` endpoints. -* *(client)* Added support for stable version of room summary endpoint. -* *(client)* Fixed parsing URL preview responses where width/height are strings. -* *(federation)* Fixed bugs in server auth. -* *(id)* Added utilities for validating server names. -* *(event)* Fixed incorrect empty `entity` field when sending hashed moderation - policy events. -* *(event)* Added [MSC4293] redact events field to member events. -* *(event)* Added support for fallbacks in [MSC4144] per-message profiles. -* *(format)* Added `MarkdownLink` and `MarkdownMention` utility functions for - generating properly escaped markdown. -* *(synapseadmin)* Added support for synchronous (v1) room delete endpoint. -* *(synapseadmin)* Changed `Client` struct to not embed the `mautrix.Client`. - This is a breaking change if you were relying on accessing non-admin functions - from the admin client. -* *(bridgev2/provisioning)* Fixed `/display_and_wait` not passing through errors - from the network connector properly. -* *(bridgev2/crypto)* Fixed encryption not working if the user's ID had the same - prefix as the bridge ghosts (e.g. `@whatsappbridgeuser:example.com` with a - `@whatsapp_` prefix). -* *(bridgev2)* Fixed portals not being saved after creating a DM portal from a - Matrix DM invite. -* *(bridgev2)* Added config option to determine whether cross-room replies - should be bridged. -* *(appservice)* Fixed `EnsureRegistered` not being called when sending a custom - member event for the controlled user. - -[MSC4293]: https://github.com/matrix-org/matrix-spec-proposals/pull/4293 - -## v0.24.0 (2025-05-16) - -* *(commands)* Added generic framework for implementing bot commands. -* *(client)* Added support for specifying maximum number of HTTP retries using - a context value instead of having to call `MakeFullRequest` manually. -* *(client,federation)* Added methods for fetching room directories. -* *(federation)* Added support for server side of request authentication. -* *(synapseadmin)* Added wrapper for the account suspension endpoint. -* *(format)* Added method for safely wrapping a string in markdown inline code. -* *(crypto)* Added method to import key backup without persisting to database, - to allow the client more control over the process. -* *(bridgev2)* Added viewing chat interface to signal when the user is viewing - a given chat. -* *(bridgev2)* Added option to pass through transaction ID from client when - sending messages to remote network. -* *(crypto)* Fixed unnecessary error log when decrypting dummy events used for - unwedging Olm sessions. -* *(crypto)* Fixed `forwarding_curve25519_key_chain` not being set consistently - when backing up keys. -* *(event)* Fixed marshaling legacy VoIP events with no version field. -* *(bridgev2)* Fixed disappearing message references not being deleted when the - portal is deleted. -* *(bridgev2)* Fixed read receipt bridging not ignoring fake message entries - and causing unnecessary error logs. - -## v0.23.3 (2025-04-16) - -* *(commands)* Added generic command processing framework for bots. -* *(client)* Added `allowed_room_ids` field to room summary responses - (thanks to [@nexy7574] in [#367]). -* *(bridgev2)* Added support for custom timeouts on outgoing messages which have - to wait for a remote echo. -* *(bridgev2)* Added automatic typing stop event if the ghost user had sent a - typing event before a message. -* *(bridgev2)* The saved management room is now cleared if the user leaves the - room, allowing the next DM to be automatically marked as a management room. -* *(bridge)* Removed deprecated fallback package for bridge statuses. - The status package is now only available under bridgev2. - -[#367]: https://github.com/mautrix/go/pull/367 - -## v0.23.2 (2025-03-16) - -* **Breaking change *(bridge)*** Removed legacy bridge module. -* **Breaking change *(event)*** Changed `m.federate` field in room create event - content to a pointer to allow detecting omitted values. -* *(bridgev2/commands)* Added `set-management-room` command to set a new - management room. -* *(bridgev2/portal)* Changed edit bridging to ignore remote edits if the - original sender on Matrix can't be puppeted. -* *(bridgv2)* Added config option to disable bridging `m.notice` messages. -* *(appservice/http)* Switched access token validation to use constant time - comparisons. -* *(event)* Added support for [MSC3765] rich text topics. -* *(event)* Added fields to policy list event contents for [MSC4204] and - [MSC4205]. -* *(client)* Added method for getting the content of a redacted event using - [MSC2815]. -* *(client)* Added methods for sending and updating [MSC4140] delayed events. -* *(client)* Added support for [MSC4222] in sync payloads. -* *(crypto/cryptohelper)* Switched to using `sqlite3-fk-wal` instead of plain - `sqlite3` by default. -* *(crypto/encryptolm)* Added generic method for encrypting to-device events. -* *(crypto/ssss)* Fixed panic if server-side key metadata is corrupted. -* *(crypto/sqlstore)* Fixed error when marking over 32 thousand device lists - as outdated on SQLite. - -[MSC2815]: https://github.com/matrix-org/matrix-spec-proposals/pull/2815 -[MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 -[MSC4140]: https://github.com/matrix-org/matrix-spec-proposals/pull/4140 -[MSC4204]: https://github.com/matrix-org/matrix-spec-proposals/pull/4204 -[MSC4205]: https://github.com/matrix-org/matrix-spec-proposals/pull/4205 -[MSC4222]: https://github.com/matrix-org/matrix-spec-proposals/pull/4222 - -## v0.23.1 (2025-02-16) - -* *(client)* Added `FullStateEvent` method to get a state event including - metadata (using the `?format=event` query parameter). -* *(client)* Added wrapper method for [MSC4194]'s redact endpoint. -* *(pushrules)* Fixed content rules not considering word boundaries and being - case-sensitive. -* *(crypto)* Fixed bugs that would cause key exports to fail for no reason. -* *(crypto)* Deprecated `ResolveTrust` in favor of `ResolveTrustContext`. -* *(crypto)* Stopped accepting secret shares from unverified devices. -* **Breaking change *(crypto)*** Changed `GetAndVerifyLatestKeyBackupVersion` - to take an optional private key parameter. The method will now trust the - public key if it matches the provided private key even if there are no valid - signatures. -* **Breaking change *(crypto)*** Added context parameter to `IsDeviceTrusted`. - -[MSC4194]: https://github.com/matrix-org/matrix-spec-proposals/pull/4194 - -## v0.23.0 (2025-01-16) - -* **Breaking change *(client)*** Changed `JoinRoom` parameters to allow multiple - `via`s. -* **Breaking change *(bridgev2)*** Updated capability system. - * The return type of `NetworkAPI.GetCapabilities` is now different. - * Media type capabilities are enforced automatically by bridgev2. - * Capabilities are now sent to Matrix rooms using the - `com.beeper.room_features` state event. -* *(client)* Added `GetRoomSummary` to implement [MSC3266]. -* *(client)* Added support for arbitrary profile fields to implement [MSC4133] - (thanks to [@nexy7574] in [#337]). -* *(crypto)* Started storing olm message hashes to prevent decryption errors - if messages are repeated (e.g. if the app crashes right after decrypting). -* *(crypto)* Improved olm session unwedging to check when the last session was - created instead of only relying on an in-memory map. -* *(crypto/verificationhelper)* Fixed emoji verification not doing cross-signing - properly after a successful verification. -* *(bridgev2/config)* Moved MSC4190 flag from `appservice` to `encryption`. -* *(bridgev2/space)* Fixed failing to add rooms to spaces if the room create - call was made with a temporary context. -* *(bridgev2/commands)* Changed `help` command to hide commands which require - interfaces that aren't implemented by the network connector. -* *(bridgev2/matrixinterface)* Moved deterministic room ID generation to Matrix - connector. -* *(bridgev2)* Fixed service member state event not being set correctly when - creating a DM by inviting a ghost user. -* *(bridgev2)* Fixed `RemoteReactionSync` events replacing all reactions every - time instead of only changed ones. - -[MSC3266]: https://github.com/matrix-org/matrix-spec-proposals/pull/3266 -[MSC4133]: https://github.com/matrix-org/matrix-spec-proposals/pull/4133 -[@nexy7574]: https://github.com/nexy7574 -[#337]: https://github.com/mautrix/go/pull/337 - -## v0.22.1 (2024-12-16) - -* *(crypto)* Added automatic cleanup when there are too many olm sessions with - a single device. -* *(crypto)* Added helper for getting cached device list with cross-signing - status. -* *(crypto/verificationhelper)* Added interface for persisting the state of - in-progress verifications. -* *(client)* Added `GetMutualRooms` wrapper for [MSC2666]. -* *(client)* Switched `JoinRoom` to use the `via` query param instead of - `server_name` as per [MSC4156]. -* *(bridgev2/commands)* Fixed `pm` command not actually starting the chat. -* *(bridgev2/interface)* Added separate network API interface for starting - chats with a Matrix ghost user. This allows treating internal user IDs - differently than arbitrary user-input strings. -* *(bridgev2/crypto)* Added support for [MSC4190] - (thanks to [@onestacked] in [#288]). - -[MSC2666]: https://github.com/matrix-org/matrix-spec-proposals/pull/2666 -[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156 -[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 -[#288]: https://github.com/mautrix/go/pull/288 -[@onestacked]: https://github.com/onestacked - -## v0.22.0 (2024-11-16) - -* *(hicli)* Moved package into gomuks repo. -* *(bridgev2/commands)* Fixed cookie unescaping in login commands. -* *(bridgev2/portal)* Added special `DefaultChatName` constant to explicitly - reset portal names to the default (based on members). -* *(bridgev2/config)* Added options to disable room tag bridging. -* *(bridgev2/database)* Fixed reaction queries not including portal receiver. -* *(appservice)* Updated [MSC2409] stable registration field name from - `push_ephemeral` to `receive_ephemeral`. Homeserver admins must update - existing registrations manually. -* *(format)* Added support for `img` tags. -* *(format/mdext)* Added goldmark extensions for Matrix math and custom emojis. -* *(event/reply)* Removed support for generating reply fallbacks ([MSC2781]). -* *(pushrules)* Added support for `sender_notification_permission` condition - kind (used for `@room` mentions). -* *(crypto)* Added support for `json.RawMessage` in `EncryptMegolmEvent`. -* *(mediaproxy)* Added `GetMediaResponseCallback` and `GetMediaResponseFile` - to write proxied data directly to http response or temp file instead of - having to use an `io.Reader`. -* *(mediaproxy)* Dropped support for legacy media download endpoints. -* *(mediaproxy,bridgev2)* Made interface pass through query parameters. - -[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 - -## v0.21.1 (2024-10-16) - -* *(bridgev2)* Added more features and fixed bugs. -* *(hicli)* Added more features and fixed bugs. -* *(appservice)* Removed TLS support. A reverse proxy should be used if TLS - is needed. -* *(format/mdext)* Added goldmark extension to fix indented paragraphs when - disabling indented code block parser. -* *(event)* Added `Has` method for `Mentions`. -* *(event)* Added basic support for the unstable version of polls. - -## v0.21.0 (2024-09-16) - -* **Breaking change *(client)*** Dropped support for unauthenticated media. - Matrix v1.11 support is now required from the homeserver, although it's not - enforced using `/versions` as some servers don't advertise it. -* *(bridgev2)* Added more features and fixed bugs. -* *(appservice,crypto)* Added support for using MSC3202 for appservice - encryption. -* *(crypto/olm)* Made everything into an interface to allow side-by-side - testing of libolm and goolm, as well as potentially support vodozemac - in the future. -* *(client)* Fixed requests being retried even after context is canceled. -* *(client)* Added option to move `/sync` request logs to trace level. -* *(error)* Added `Write` and `WithMessage` helpers to `RespError` to make it - easier to use on servers. -* *(event)* Fixed `org.matrix.msc1767.audio` field allowing omitting the - duration and waveform. -* *(id)* Changed `MatrixURI` methods to not panic if the receiver is nil. -* *(federation)* Added limit to response size when fetching `.well-known` files. - ## v0.20.0 (2024-08-16) * Bumped minimum Go version to 1.22. diff --git a/README.md b/README.md index b1a2edf8..ac41ca78 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,8 @@ # mautrix-go [![GoDoc](https://pkg.go.dev/badge/maunium.net/go/mautrix)](https://pkg.go.dev/maunium.net/go/mautrix) -A Golang Matrix framework. Used by [gomuks](https://gomuks.app), -[go-neb](https://github.com/matrix-org/go-neb), -[mautrix-whatsapp](https://github.com/mautrix/whatsapp) +A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), +[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net) @@ -14,10 +13,9 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or In addition to the basic client API features the original project has, this framework also has: * Appservice support (Intent API like mautrix-python, room state storage, etc) -* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc) +* End-to-end encryption support (incl. interactive SAS verification) * High-level module for building puppeting bridges -* Partial federation module (making requests, PDU processing and event authorization) -* A media proxy server which can be used to expose anything as a Matrix media repo +* High-level module for building chat clients * Wrapper functions for the Synapse admin API * Structs for parsing event content * Helpers for parsing and generating Matrix HTML diff --git a/appservice/appservice.go b/appservice/appservice.go index d7037ef6..90ace5d9 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2023 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,7 +19,8 @@ import ( "syscall" "time" - "github.com/coder/websocket" + "github.com/gorilla/mux" + "github.com/gorilla/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v3" @@ -31,7 +32,7 @@ import ( // EventChannelSize is the size for the Events channel in Appservice instances. var EventChannelSize = 64 -var OTKChannelSize = 64 +var OTKChannelSize = 4 // Create creates a blank appservice instance. func Create() *AppService { @@ -42,7 +43,7 @@ func Create() *AppService { intents: make(map[id.UserID]*IntentAPI), HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar}, StateStore: mautrix.NewMemoryStateStore().(StateStore), - Router: http.NewServeMux(), + Router: mux.NewRouter(), UserAgent: mautrix.DefaultUserAgent, txnIDC: NewTransactionIDCache(128), Live: true, @@ -60,12 +61,12 @@ func Create() *AppService { DefaultHTTPRetries: 4, } - as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction) - as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom) - as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser) - as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing) - as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive) - as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady) + as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) + as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) + as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) + as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost) + as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) + as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) return as } @@ -113,13 +114,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil) // QueryHandler handles room alias and user ID queries from the homeserver. type QueryHandler interface { - QueryAlias(alias id.RoomAlias) bool + QueryAlias(alias string) bool QueryUser(userID id.UserID) bool } type QueryHandlerStub struct{} -func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool { +func (qh *QueryHandlerStub) QueryAlias(alias string) bool { return false } @@ -127,7 +128,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool { return false } -type WebsocketHandler func(WebsocketCommand) (ok bool, data any) +type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) type StateStore interface { mautrix.StateStore @@ -159,7 +160,7 @@ type AppService struct { QueryHandler QueryHandler StateStore StateStore - Router *http.ServeMux + Router *mux.Router UserAgent string server *http.Server HTTPClient *http.Client @@ -178,6 +179,7 @@ type AppService struct { intentsLock sync.RWMutex ws *websocket.Conn + wsWriteLock sync.Mutex StopWebsocket func(error) websocketHandlers map[string]WebsocketHandler websocketHandlersLock sync.RWMutex @@ -222,6 +224,9 @@ type HostConfig struct { Hostname string `yaml:"hostname"` // Port is required when Hostname is an IP address, optional for unix sockets Port uint16 `yaml:"port"` + + TLSKey string `yaml:"tls_key,omitempty"` + TLSCert string `yaml:"tls_cert,omitempty"` } // Address gets the whole address of the Appservice. @@ -334,7 +339,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error { } else if as.hsURLForClient.Scheme == "" { as.hsURLForClient.Scheme = "https" } - as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath() + as.hsURLForClient.RawPath = parsedURL.EscapedPath() jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar} @@ -360,7 +365,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { AccessToken: as.Registration.AppToken, UserAgent: as.UserAgent, StateStore: as.StateStore, - Log: as.Log.With().Stringer("as_user_id", userID).Logger(), + Log: as.Log.With().Str("as_user_id", userID.String()).Logger(), Client: as.HTTPClient, DefaultHTTPRetries: as.DefaultHTTPRetries, SpecVersions: as.SpecVersions, diff --git a/appservice/http.go b/appservice/http.go index 27ce6288..47f6a282 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2023 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,9 +17,8 @@ import ( "syscall" "time" + "github.com/gorilla/mux" "github.com/rs/zerolog" - "go.mau.fi/util/exhttp" - "go.mau.fi/util/exstrings" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -60,8 +59,13 @@ func (as *AppService) listenUnix() error { } func (as *AppService) listenTCP() error { - as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") - return as.server.ListenAndServe() + if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 { + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") + return as.server.ListenAndServe() + } else { + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener with TLS") + return as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey) + } } func (as *AppService) Stop() { @@ -79,9 +83,17 @@ func (as *AppService) Stop() { func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { authHeader := r.Header.Get("Authorization") if !strings.HasPrefix(authHeader, "Bearer ") { - mautrix.MMissingToken.WithMessage("Missing access token").Write(w) - } else if !exstrings.ConstantTimeEqual(authHeader[len("Bearer "):], as.Registration.ServerToken) { - mautrix.MUnknownToken.WithMessage("Invalid access token").Write(w) + Error{ + ErrorCode: ErrUnknownToken, + HTTPStatus: http.StatusForbidden, + Message: "Missing access token", + }.Write(w) + } else if authHeader[len("Bearer "):] != as.Registration.ServerToken { + Error{ + ErrorCode: ErrUnknownToken, + HTTPStatus: http.StatusForbidden, + Message: "Incorrect access token", + }.Write(w) } else { isValid = true } @@ -94,15 +106,24 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { return } - txnID := r.PathValue("txnID") + vars := mux.Vars(r) + txnID := vars["txnID"] if len(txnID) == 0 { - mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w) + Error{ + ErrorCode: ErrNoTransactionID, + HTTPStatus: http.StatusBadRequest, + Message: "Missing transaction ID", + }.Write(w) return } defer r.Body.Close() body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 { - mautrix.MNotJSON.WithMessage("Failed to read response body").Write(w) + Error{ + ErrorCode: ErrNotJSON, + HTTPStatus: http.StatusBadRequest, + Message: "Missing request body", + }.Write(w) return } log := as.Log.With().Str("transaction_id", txnID).Logger() @@ -111,7 +132,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { ctx = log.WithContext(ctx) if as.txnIDC.IsProcessed(txnID) { // Duplicate transaction ID: no-op - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + WriteBlankOK(w) log.Debug().Msg("Ignoring duplicate transaction") return } @@ -120,10 +141,14 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { err = json.Unmarshal(body, &txn) if err != nil { log.Error().Err(err).Msg("Failed to parse transaction content") - mautrix.MBadJSON.WithMessage("Failed to parse transaction content").Write(w) + Error{ + ErrorCode: ErrBadJSON, + HTTPStatus: http.StatusBadRequest, + Message: "Failed to parse body JSON", + }.Write(w) } else { as.handleTransaction(ctx, txnID, &txn) - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + WriteBlankOK(w) } } @@ -201,7 +226,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def } err := evt.Content.ParseRaw(evt.Type) if errors.Is(err, event.ErrUnsupportedContentType) { - log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event") + log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event") } else if err != nil { log.Warn().Err(err). Str("event_id", evt.ID.String()). @@ -238,12 +263,16 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { return } - roomAlias := id.RoomAlias(r.PathValue("roomAlias")) + vars := mux.Vars(r) + roomAlias := vars["roomAlias"] ok := as.QueryHandler.QueryAlias(roomAlias) if ok { - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + WriteBlankOK(w) } else { - mautrix.MNotFound.WithMessage("Alias not found").Write(w) + Error{ + ErrorCode: ErrUnknown, + HTTPStatus: http.StatusNotFound, + }.Write(w) } } @@ -253,12 +282,16 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { return } - userID := id.UserID(r.PathValue("userID")) + vars := mux.Vars(r) + userID := id.UserID(vars["userID"]) ok := as.QueryHandler.QueryUser(userID) if ok { - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + WriteBlankOK(w) } else { - mautrix.MNotFound.WithMessage("User not found").Write(w) + Error{ + ErrorCode: ErrUnknown, + HTTPStatus: http.StatusNotFound, + }.Write(w) } } @@ -268,7 +301,11 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { } body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 || !json.Valid(body) { - mautrix.MNotJSON.WithMessage("Invalid or missing request body").Write(w) + Error{ + ErrorCode: ErrNotJSON, + HTTPStatus: http.StatusBadRequest, + Message: "Missing request body", + }.Write(w) return } @@ -276,21 +313,27 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { _ = json.Unmarshal(body, &txn) as.Log.Debug().Str("txn_id", txn.TxnID).Msg("Received ping from homeserver") - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte("{}")) } func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") if as.Live { - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + w.WriteHeader(http.StatusOK) } else { - exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) } + w.Write([]byte("{}")) } func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") if as.Ready { - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + w.WriteHeader(http.StatusOK) } else { - exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) } + w.Write([]byte("{}")) } diff --git a/appservice/intent.go b/appservice/intent.go index 5d43f190..6848f28c 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } func (intent *IntentAPI) Register(ctx context.Context) error { - _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{ + _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, @@ -86,7 +86,6 @@ func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { type EnsureJoinedParams struct { IgnoreCache bool BotOverride *mautrix.Client - Via []string } func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error { @@ -100,17 +99,11 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } - err := intent.EnsureRegistered(ctx) - if err != nil { + if err := intent.EnsureRegistered(ctx); err != nil { return fmt.Errorf("failed to ensure joined: %w", err) } - var resp *mautrix.RespJoinRoom - if len(params.Via) > 0 { - resp, err = intent.JoinRoom(ctx, roomID.String(), &mautrix.ReqJoinRoom{Via: params.Via}) - } else { - resp, err = intent.JoinRoomByID(ctx, roomID) - } + resp, err := intent.JoinRoomByID(ctx, roomID) if err != nil { bot := intent.bot if params.BotOverride != nil { @@ -149,16 +142,12 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } -func (intent *IntentAPI) IsDoublePuppet() bool { - return intent.IsCustomPuppet && intent.as.DoublePuppetValue != "" -} - func (intent *IntentAPI) AddDoublePuppetValue(into any) any { return intent.AddDoublePuppetValueWithTS(into, 0) } func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { - if !intent.IsDoublePuppet() { + if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" { return into } // Only use ts deduplication feature with appservice double puppeting @@ -214,45 +203,38 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { } } -func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) } -func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { - return nil, err - } - if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { - return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") - } - contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...) -} - -// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + if err := intent.EnsureJoined(ctx, roomID); err != nil { + return nil, err + } + contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } -func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - } else if err := intent.EnsureRegistered(ctx); err != nil { - return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) } -// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + if err := intent.EnsureJoined(ctx, roomID); err != nil { + return nil, err + } + contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) + return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) } func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { @@ -311,7 +293,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...) - return &mautrix.RespJoinRoom{RoomID: roomID}, err + return &mautrix.RespJoinRoom{}, err } return intent.Client.JoinRoomByID(ctx, roomID) } @@ -380,24 +362,6 @@ func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id return member } -func (intent *IntentAPI) FillPowerLevelCreateEvent(ctx context.Context, roomID id.RoomID, pl *event.PowerLevelsEventContent) error { - if pl.CreateEvent != nil { - return nil - } - var err error - pl.CreateEvent, err = intent.StateStore.GetCreate(ctx, roomID) - if err != nil { - return fmt.Errorf("failed to get create event from cache: %w", err) - } else if pl.CreateEvent != nil { - return nil - } - pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") - if err != nil { - return fmt.Errorf("failed to get create event from server: %w", err) - } - return nil -} - func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID) if err != nil { @@ -407,12 +371,6 @@ func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl if pl == nil { pl = &event.PowerLevelsEventContent{} err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) - if err != nil { - return - } - } - if pl.CreateEvent == nil { - pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") } return } @@ -427,7 +385,8 @@ func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, us return nil, err } - if pl.EnsureUserLevelAs(intent.UserID, userID, level) { + if pl.GetUserLevel(userID) != level { + pl.SetUserLevel(userID, level) return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl) } return nil, nil @@ -516,7 +475,7 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU // No need to update return nil } - if !avatarURL.IsEmpty() && !intent.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { + if !avatarURL.IsEmpty() { // Some homeservers require the avatar to be downloaded before setting it resp, _ := intent.Download(ctx, avatarURL) if resp != nil { diff --git a/appservice/ping.go b/appservice/ping.go deleted file mode 100644 index 774ec423..00000000 --- a/appservice/ping.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package appservice - -import ( - "context" - "encoding/json" - "errors" - "os" - "strings" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" -) - -func (intent *IntentAPI) EnsureAppserviceConnection(ctx context.Context) { - var pingResp *mautrix.RespAppservicePing - var txnID string - var retryCount int - var err error - const maxRetries = 6 - for { - txnID = intent.TxnID() - pingResp, err = intent.AppservicePing(ctx, intent.as.Registration.ID, txnID) - if err == nil { - break - } - var httpErr mautrix.HTTPError - var pingErrBody string - if errors.As(err, &httpErr) && httpErr.RespError != nil { - if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { - pingErrBody = strings.TrimSpace(val) - } - } - outOfRetries := retryCount >= maxRetries - level := zerolog.ErrorLevel - if outOfRetries { - level = zerolog.FatalLevel - } - evt := zerolog.Ctx(ctx).WithLevel(level).Err(err).Str("txn_id", txnID) - if pingErrBody != "" { - bodyBytes := []byte(pingErrBody) - if json.Valid(bodyBytes) { - evt.RawJSON("body", bodyBytes) - } else { - evt.Str("body", pingErrBody) - } - } - if outOfRetries { - evt.Msg("Homeserver -> appservice connection is not working") - zerolog.Ctx(ctx).Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") - os.Exit(13) - } - evt.Msg("Homeserver -> appservice connection is not working, retrying in 5 seconds...") - time.Sleep(5 * time.Second) - retryCount++ - } - zerolog.Ctx(ctx).Debug(). - Str("txn_id", txnID). - Int64("duration_ms", pingResp.DurationMS). - Msg("Homeserver -> appservice connection works") -} diff --git a/appservice/protocol.go b/appservice/protocol.go index 7c493bcb..7a9891ef 100644 --- a/appservice/protocol.go +++ b/appservice/protocol.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2023 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,7 +7,9 @@ package appservice import ( + "encoding/json" "fmt" + "net/http" "strings" "github.com/rs/zerolog" @@ -101,3 +103,50 @@ func (txn *Transaction) ContentString() string { // EventListener is a function that receives events. type EventListener func(evt *event.Event) + +// WriteBlankOK writes a blank OK message as a reply to a HTTP request. +func WriteBlankOK(w http.ResponseWriter) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) +} + +// Respond responds to a HTTP request with a JSON object. +func Respond(w http.ResponseWriter, data interface{}) error { + w.Header().Add("Content-Type", "application/json") + dataStr, err := json.Marshal(data) + if err != nil { + return err + } + _, err = w.Write(dataStr) + return err +} + +// Error represents a Matrix protocol error. +type Error struct { + HTTPStatus int `json:"-"` + ErrorCode ErrorCode `json:"errcode"` + Message string `json:"error"` +} + +func (err Error) Write(w http.ResponseWriter) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(err.HTTPStatus) + _ = Respond(w, &err) +} + +// ErrorCode is the machine-readable code in an Error. +type ErrorCode string + +// Native ErrorCodes +const ( + ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN" + ErrBadJSON ErrorCode = "M_BAD_JSON" + ErrNotJSON ErrorCode = "M_NOT_JSON" + ErrUnknown ErrorCode = "M_UNKNOWN" +) + +// Custom ErrorCodes +const ( + ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID" +) diff --git a/appservice/registration.go b/appservice/registration.go index 54eff716..b11bd84b 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -27,9 +27,7 @@ type Registration struct { Protocols []string `yaml:"protocols,omitempty" json:"protocols,omitempty"` SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"` - EphemeralEvents bool `yaml:"receive_ephemeral,omitempty" json:"receive_ephemeral,omitempty"` - MSC3202 bool `yaml:"org.matrix.msc3202,omitempty" json:"org.matrix.msc3202,omitempty"` - MSC4190 bool `yaml:"io.element.msc4190,omitempty" json:"io.element.msc4190,omitempty"` + EphemeralEvents bool `yaml:"push_ephemeral,omitempty" json:"push_ephemeral,omitempty"` } // CreateRegistration creates a Registration with random appservice and homeserver tokens. diff --git a/appservice/websocket.go b/appservice/websocket.go index ef65e65a..598d70d1 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2023 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -11,26 +11,26 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" - "path" + "path/filepath" "strings" "sync" "sync/atomic" + "time" - "github.com/coder/websocket" + "github.com/gorilla/websocket" "github.com/rs/zerolog" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - - "maunium.net/go/mautrix" ) type WebsocketRequest struct { - ReqID int `json:"id,omitempty"` - Command string `json:"command"` - Data any `json:"data"` + ReqID int `json:"id,omitempty"` + Command string `json:"command"` + Data interface{} `json:"data"` + + Deadline time.Duration `json:"-"` } type WebsocketCommand struct { @@ -41,7 +41,7 @@ type WebsocketCommand struct { Ctx context.Context `json:"-"` } -func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { +func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest { if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" { return nil } @@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) - if len(errorData) > 2 && jsonErr == nil { + if errorData != nil && len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break @@ -98,8 +98,8 @@ type WebsocketMessage struct { } const ( - WebsocketCloseConnReplaced websocket.StatusCode = 4001 - WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002 + WebsocketCloseConnReplaced = 4001 + WebsocketCloseTxnNotAcknowledged = 4002 ) type MeowWebsocketCloseCode string @@ -133,7 +133,7 @@ func (mwcc MeowWebsocketCloseCode) String() string { } type CloseCommand struct { - Code websocket.StatusCode `json:"-"` + Code int `json:"-"` Command string `json:"command"` Status MeowWebsocketCloseCode `json:"status"` } @@ -143,15 +143,15 @@ func (cc CloseCommand) Error() string { } func parseCloseError(err error) error { - var closeError websocket.CloseError + closeError := &websocket.CloseError{} if !errors.As(err, &closeError) { return err } var closeCommand CloseCommand closeCommand.Code = closeError.Code closeCommand.Command = "disconnect" - if len(closeError.Reason) > 0 { - jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand) + if len(closeError.Text) > 0 { + jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand) if jsonErr != nil { return err } @@ -159,7 +159,7 @@ func parseCloseError(err error) error { if len(closeCommand.Status) == 0 { if closeCommand.Code == WebsocketCloseConnReplaced { closeCommand.Status = MeowConnectionReplaced - } else if closeCommand.Code == websocket.StatusServiceRestart { + } else if closeCommand.Code == websocket.CloseServiceRestart { closeCommand.Status = MeowServerShuttingDown } } @@ -170,23 +170,20 @@ func (as *AppService) HasWebsocket() bool { return as.ws != nil } -func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error { +func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error { ws := as.ws if cmd == nil { return nil } else if ws == nil { return ErrWebsocketNotConnected } - wr, err := ws.Writer(ctx, websocket.MessageText) - if err != nil { - return err + as.wsWriteLock.Lock() + defer as.wsWriteLock.Unlock() + if cmd.Deadline == 0 { + cmd.Deadline = 3 * time.Minute } - err = json.NewEncoder(wr).Encode(cmd) - if err != nil { - _ = wr.Close() - return err - } - return wr.Close() + _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline)) + return ws.WriteJSON(cmd) } func (as *AppService) clearWebsocketResponseWaiters() { @@ -223,12 +220,12 @@ func (er *ErrorResponse) Error() string { return fmt.Sprintf("%s: %s", er.Code, er.Message) } -func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error { +func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error { cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1)) respChan := make(chan *WebsocketCommand, 1) as.addWebsocketResponseWaiter(cmd.ReqID, respChan) defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan) - err := as.SendWebsocket(ctx, cmd) + err := as.SendWebsocket(cmd) if err != nil { return err } @@ -257,7 +254,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques } } -func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) { +func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) { zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command") return false, fmt.Errorf("unknown request type") } @@ -281,28 +278,14 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg return true, &WebsocketTransactionResponse{TxnID: msg.TxnID} } -func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) { +func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) { defer stopFunc(ErrWebsocketUnknownError) + ctx := context.Background() for { - msgType, reader, err := ws.Reader(ctx) - if err != nil { - as.Log.Debug().Err(err).Msg("Error getting reader from websocket") - stopFunc(parseCloseError(err)) - return - } else if msgType != websocket.MessageText { - as.Log.Debug().Msg("Ignoring non-text message from websocket") - continue - } - data, err := io.ReadAll(reader) - if err != nil { - as.Log.Debug().Err(err).Msg("Error reading data from websocket") - stopFunc(parseCloseError(err)) - return - } var msg WebsocketMessage - err = json.Unmarshal(data, &msg) + err := ws.ReadJSON(&msg) if err != nil { - as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket") + as.Log.Debug().Err(err).Msg("Error reading from websocket") stopFunc(parseCloseError(err)) return } @@ -313,11 +296,11 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error) with = with.Str("transaction_id", msg.TxnID) } log := with.Logger() - ctx := log.WithContext(ctx) + ctx = log.WithContext(ctx) if msg.Command == "" || msg.Command == "transaction" { ok, resp := as.WebsocketTransactionHandler(ctx, msg) go func() { - err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp)) + err := as.SendWebsocket(msg.MakeResponse(ok, resp)) if err != nil { log.Warn().Err(err).Msg("Failed to send response to websocket transaction") } else { @@ -349,7 +332,7 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error) } go func() { okResp, data := handler(msg.WebsocketCommand) - err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data)) + err := as.SendWebsocket(msg.MakeResponse(okResp, data)) if err != nil { log.Error().Err(err).Msg("Failed to send response to websocket command") } else if okResp { @@ -362,7 +345,7 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error) } } -func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error { +func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { var parsed *url.URL if baseURL != "" { var err error @@ -374,29 +357,26 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn copiedURL := *as.hsURLForClient parsed = &copiedURL } - parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") + parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") if parsed.Scheme == "http" { parsed.Scheme = "ws" } else if parsed.Scheme == "https" { parsed.Scheme = "wss" } - ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{ - HTTPClient: as.HTTPClient, - HTTPHeader: http.Header{ - "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, - "User-Agent": []string{as.BotClient().UserAgent}, + ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, + "User-Agent": []string{as.BotClient().UserAgent}, - "X-Mautrix-Process-ID": []string{as.ProcessID}, - "X-Mautrix-Websocket-Version": []string{"3"}, - }, + "X-Mautrix-Process-ID": []string{as.ProcessID}, + "X-Mautrix-Websocket-Version": []string{"3"}, }) if resp != nil && resp.StatusCode >= 400 { - var errResp mautrix.RespError + var errResp Error err = json.NewDecoder(resp.Body).Decode(&errResp) if err != nil { return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode) } else { - return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrCode, resp.StatusCode, errResp.Err) + return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message) } } else if err != nil { return fmt.Errorf("failed to open websocket: %w", err) @@ -419,13 +399,12 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn } }) } - ws.SetReadLimit(50 * 1024 * 1024) as.ws = ws as.StopWebsocket = stopFunc as.PrepareWebsocket() as.Log.Debug().Msg("Appservice transaction websocket opened") - go as.consumeWebsocket(ctx, stopFunc, ws) + go as.consumeWebsocket(stopFunc, ws) var onConnectDone atomic.Bool if onConnect != nil { @@ -447,7 +426,12 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn as.ws = nil } - err = ws.Close(websocket.StatusGoingAway, "") + _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second)) + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "")) + if err != nil && !errors.Is(err, websocket.ErrCloseSent) { + as.Log.Warn().Err(err).Msg("Error writing close message to websocket") + } + err = ws.Close() if err != nil { as.Log.Warn().Err(err).Msg("Error closing websocket") } diff --git a/bridge/bridge.go b/bridge/bridge.go new file mode 100644 index 00000000..17a4a30c --- /dev/null +++ b/bridge/bridge.go @@ -0,0 +1,936 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridge + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "os/signal" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/lib/pq" + "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/configupgrade" + "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exzerolog" + "gopkg.in/yaml.v3" + flag "maunium.net/go/mauflag" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/bridgeconfig" + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/sqlstatestore" +) + +var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() +var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool() +var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() +var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() +var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() +var versionJSON = flag.Make().LongKey("version-json").Usage("Print a JSON object representing the bridge version and quit.").Default("false").Bool() +var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() +var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() +var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() +var wantHelp, _ = flag.MakeHelpFlag() + +var _ appservice.StateStore = (*sqlstatestore.SQLStateStore)(nil) + +type Portal interface { + IsEncrypted() bool + IsPrivateChat() bool + MarkEncrypted() + MainIntent() *appservice.IntentAPI + + ReceiveMatrixEvent(user User, evt *event.Event) + UpdateBridgeInfo(ctx context.Context) +} + +type MembershipHandlingPortal interface { + Portal + HandleMatrixLeave(sender User, evt *event.Event) + HandleMatrixKick(sender User, ghost Ghost, evt *event.Event) + HandleMatrixInvite(sender User, ghost Ghost, evt *event.Event) +} + +type ReadReceiptHandlingPortal interface { + Portal + HandleMatrixReadReceipt(sender User, eventID id.EventID, receipt event.ReadReceipt) +} + +type TypingPortal interface { + Portal + HandleMatrixTyping(userIDs []id.UserID) +} + +type MetaHandlingPortal interface { + Portal + HandleMatrixMeta(sender User, evt *event.Event) +} + +type DisappearingPortal interface { + Portal + ScheduleDisappearing() +} + +type PowerLevelHandlingPortal interface { + Portal + HandleMatrixPowerLevels(sender User, evt *event.Event) +} + +type JoinRuleHandlingPortal interface { + Portal + HandleMatrixJoinRule(sender User, evt *event.Event) +} + +type BanHandlingPortal interface { + Portal + HandleMatrixBan(sender User, ghost Ghost, evt *event.Event) + HandleMatrixUnban(sender User, ghost Ghost, evt *event.Event) +} + +type KnockHandlingPortal interface { + Portal + HandleMatrixKnock(sender User, evt *event.Event) + HandleMatrixRetractKnock(sender User, evt *event.Event) + HandleMatrixAcceptKnock(sender User, ghost Ghost, evt *event.Event) + HandleMatrixRejectKnock(sender User, ghost Ghost, evt *event.Event) +} + +type InviteHandlingPortal interface { + Portal + HandleMatrixAcceptInvite(sender User, evt *event.Event) + HandleMatrixRejectInvite(sender User, evt *event.Event) + HandleMatrixRetractInvite(sender User, ghost Ghost, evt *event.Event) +} + +type User interface { + GetPermissionLevel() bridgeconfig.PermissionLevel + IsLoggedIn() bool + GetManagementRoomID() id.RoomID + SetManagementRoom(id.RoomID) + GetMXID() id.UserID + GetIDoublePuppet() DoublePuppet + GetIGhost() Ghost +} + +type DoublePuppet interface { + CustomIntent() *appservice.IntentAPI + SwitchCustomMXID(accessToken string, userID id.UserID) error + ClearCustomMXID() +} + +type Ghost interface { + DoublePuppet + DefaultIntent() *appservice.IntentAPI + GetMXID() id.UserID +} + +type GhostWithProfile interface { + Ghost + GetDisplayname() string + GetAvatarURL() id.ContentURI +} + +type ChildOverride interface { + GetExampleConfig() string + GetConfigPtr() interface{} + + Init() + Start() + Stop() + + GetIPortal(id.RoomID) Portal + GetAllIPortals() []Portal + GetIUser(id id.UserID, create bool) User + IsGhost(id.UserID) bool + GetIGhost(id.UserID) Ghost + CreatePrivatePortal(id.RoomID, User, Ghost) +} + +type ConfigValidatingBridge interface { + ChildOverride + ValidateConfig() error +} + +type FlagHandlingBridge interface { + ChildOverride + HandleFlags() bool +} + +type PreInitableBridge interface { + ChildOverride + PreInit() +} + +type WebsocketStartingBridge interface { + ChildOverride + OnWebsocketConnect() +} + +type CSFeatureRequirer interface { + CheckFeatures(versions *mautrix.RespVersions) (string, bool) +} + +type Bridge struct { + Name string + URL string + Description string + Version string + ProtocolName string + BeeperServiceName string + BeeperNetworkName string + + AdditionalShortFlags string + AdditionalLongFlags string + + VersionDesc string + LinkifiedVersion string + BuildTime string + commit string + baseVersion string + + PublicHSAddress *url.URL + + DoublePuppet *doublePuppetUtil + + AS *appservice.AppService + EventProcessor *appservice.EventProcessor + CommandProcessor CommandProcessor + MatrixHandler *MatrixHandler + Bot *appservice.IntentAPI + Config bridgeconfig.BaseConfig + ConfigPath string + RegistrationPath string + SaveConfig bool + ConfigUpgrader configupgrade.BaseUpgrader + DB *dbutil.Database + StateStore *sqlstatestore.SQLStateStore + Crypto Crypto + CryptoPickleKey string + + ZLog *zerolog.Logger + + MediaConfig mautrix.RespMediaConfig + SpecVersions mautrix.RespVersions + + Child ChildOverride + + manualStop chan int + Stopping bool + + latestState *status.BridgeState + + Websocket bool + wsStopPinger chan struct{} + wsStarted chan struct{} + wsStopped chan struct{} + wsShortCircuitReconnectBackoff chan struct{} + wsStartupWait *sync.WaitGroup +} + +type Crypto interface { + HandleMemberEvent(context.Context, *event.Event) + Decrypt(context.Context, *event.Event) (*event.Event, error) + Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error + WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + ResetSession(context.Context, id.RoomID) + Init(ctx context.Context) error + Start() + Stop() + Reset(ctx context.Context, startAfterReset bool) + Client() *mautrix.Client + ShareKeys(context.Context) error +} + +func (br *Bridge) GenerateRegistration() { + if !br.SaveConfig { + // We need to save the generated as_token and hs_token in the config + _, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration") + os.Exit(5) + } else if br.Config.Homeserver.Domain == "example.com" { + _, _ = fmt.Fprintln(os.Stderr, "Homeserver domain is not set") + os.Exit(20) + } + reg := br.Config.GenerateRegistration() + err := reg.Save(br.RegistrationPath) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err) + os.Exit(21) + } + + updateTokens := func(helper configupgrade.Helper) { + helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token") + helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token") + } + _, _, err = configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(updateTokens)) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err) + os.Exit(22) + } + fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.") + os.Exit(0) +} + +func (br *Bridge) InitVersion(tag, commit, buildTime string) { + br.baseVersion = br.Version + if len(tag) > 0 && tag[0] == 'v' { + tag = tag[1:] + } + if tag != br.Version { + suffix := "" + if !strings.HasSuffix(br.Version, "+dev") { + suffix = "+dev" + } + if len(commit) > 8 { + br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) + } else { + br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) + } + } + + br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) + if tag == br.Version { + br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) + } else if len(commit) > 8 { + br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) + } + mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) + br.VersionDesc = fmt.Sprintf("%s %s (%s with %s)", br.Name, br.Version, buildTime, runtime.Version()) + br.commit = commit + br.BuildTime = buildTime +} + +var MinSpecVersion = mautrix.SpecV14 + +func (br *Bridge) logInitialRequestError(err error, defaultMessage string) { + if errors.Is(err, mautrix.MUnknownToken) { + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") + } else if errors.Is(err, mautrix.MExclusive) { + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") + } else { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg(defaultMessage) + } +} + +func (br *Bridge) ensureConnection(ctx context.Context) { + for { + versions, err := br.Bot.Versions(ctx) + if err != nil { + if errors.Is(err, mautrix.MForbidden) { + br.ZLog.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") + err = br.Bot.EnsureRegistered(ctx) + if err != nil { + br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") + os.Exit(16) + } + } else { + br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") + time.Sleep(10 * time.Second) + } + } else { + br.SpecVersions = *versions + *br.AS.SpecVersions = *versions + break + } + } + + unsupportedServerLogLevel := zerolog.FatalLevel + if *ignoreUnsupportedServer { + unsupportedServerLogLevel = zerolog.ErrorLevel + } + if br.Config.Homeserver.Software == bridgeconfig.SoftwareHungry && !br.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The config claims the homeserver is hungryserv, but the /versions response didn't confirm it") + os.Exit(18) + } else if !br.SpecVersions.ContainsGreaterOrEqual(MinSpecVersion) { + br.ZLog.WithLevel(unsupportedServerLogLevel). + Stringer("server_supports", br.SpecVersions.GetLatest()). + Stringer("bridge_requires", MinSpecVersion). + Msg("The homeserver is outdated (supported spec versions are below minimum required by bridge)") + if !*ignoreUnsupportedServer { + os.Exit(18) + } + } else if fr, ok := br.Child.(CSFeatureRequirer); ok { + if msg, hasFeatures := fr.CheckFeatures(&br.SpecVersions); !hasFeatures { + br.ZLog.WithLevel(unsupportedServerLogLevel).Msg(msg) + if !*ignoreUnsupportedServer { + os.Exit(18) + } + } + } + + resp, err := br.Bot.Whoami(ctx) + if err != nil { + br.logInitialRequestError(err, "/whoami request failed with unknown error") + os.Exit(16) + } else if resp.UserID != br.Bot.UserID { + br.ZLog.WithLevel(zerolog.FatalLevel). + Stringer("got_user_id", resp.UserID). + Stringer("expected_user_id", br.Bot.UserID). + Msg("Unexpected user ID in whoami call") + os.Exit(17) + } + + if br.Websocket { + br.ZLog.Debug().Msg("Websocket mode: no need to check status of homeserver -> bridge connection") + return + } else if !br.SpecVersions.Supports(mautrix.FeatureAppservicePing) { + br.ZLog.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") + return + } + var pingResp *mautrix.RespAppservicePing + var txnID string + var retryCount int + const maxRetries = 6 + for { + txnID = br.Bot.TxnID() + pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) + if err == nil { + break + } + var httpErr mautrix.HTTPError + var pingErrBody string + if errors.As(err, &httpErr) && httpErr.RespError != nil { + if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { + pingErrBody = strings.TrimSpace(val) + } + } + outOfRetries := retryCount >= maxRetries + level := zerolog.ErrorLevel + if outOfRetries { + level = zerolog.FatalLevel + } + evt := br.ZLog.WithLevel(level).Err(err).Str("txn_id", txnID) + if pingErrBody != "" { + bodyBytes := []byte(pingErrBody) + if json.Valid(bodyBytes) { + evt.RawJSON("body", bodyBytes) + } else { + evt.Str("body", pingErrBody) + } + } + if outOfRetries { + evt.Msg("Homeserver -> bridge connection is not working") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") + os.Exit(13) + } + evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") + time.Sleep(5 * time.Second) + retryCount++ + } + br.ZLog.Debug(). + Str("txn_id", txnID). + Int64("duration_ms", pingResp.DurationMS). + Msg("Homeserver -> bridge connection works") +} + +func (br *Bridge) fetchMediaConfig(ctx context.Context) { + cfg, err := br.Bot.GetMediaConfig(ctx) + if err != nil { + br.ZLog.Warn().Err(err).Msg("Failed to fetch media config") + } else { + if cfg.UploadSize == 0 { + cfg.UploadSize = 50 * 1024 * 1024 + } + br.MediaConfig = *cfg + } +} + +func (br *Bridge) UpdateBotProfile(ctx context.Context) { + br.ZLog.Debug().Msg("Updating bot profile") + botConfig := &br.Config.AppService.Bot + + var err error + var mxc id.ContentURI + if botConfig.Avatar == "remove" { + err = br.Bot.SetAvatarURL(ctx, mxc) + } else if !botConfig.ParsedAvatar.IsEmpty() { + err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) + } + if err != nil { + br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar") + } + + if botConfig.Displayname == "remove" { + err = br.Bot.SetDisplayName(ctx, "") + } else if len(botConfig.Displayname) > 0 { + err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) + } + if err != nil { + br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname") + } + + if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" { + br.ZLog.Debug().Msg("Setting contact info on the appservice bot") + br.Bot.BeeperUpdateProfile(ctx, map[string]any{ + "com.beeper.bridge.service": br.BeeperServiceName, + "com.beeper.bridge.network": br.BeeperNetworkName, + "com.beeper.bridge.is_bridge_bot": true, + }) + } +} + +func (br *Bridge) loadConfig() { + configData, upgraded, err := configupgrade.Do(br.ConfigPath, br.SaveConfig, br.ConfigUpgrader) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err) + if configData == nil { + os.Exit(10) + } + } + + target := br.Child.GetConfigPtr() + if !upgraded { + // Fallback: if config upgrading failed, load example config for base values + err = yaml.Unmarshal([]byte(br.Child.GetExampleConfig()), &target) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to unmarshal example config:", err) + os.Exit(10) + } + } + err = yaml.Unmarshal(configData, target) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) + os.Exit(10) + } +} + +func (br *Bridge) validateConfig() error { + switch { + case br.Config.Homeserver.Address == "https://matrix.example.com": + return errors.New("homeserver.address not configured") + case br.Config.Homeserver.Domain == "example.com": + return errors.New("homeserver.domain not configured") + case !bridgeconfig.AllowedHomeserverSoftware[br.Config.Homeserver.Software]: + return errors.New("invalid value for homeserver.software (use `standard` if you don't know what the field is for)") + case br.Config.AppService.ASToken == "This value is generated when generating the registration": + return errors.New("appservice.as_token not configured. Did you forget to generate the registration? ") + case br.Config.AppService.HSToken == "This value is generated when generating the registration": + return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") + case br.Config.AppService.Database.URI == "postgres://user:password@host/database?sslmode=disable": + return errors.New("appservice.database not configured") + default: + err := br.Config.Bridge.Validate() + if err != nil { + return err + } + validator, ok := br.Child.(ConfigValidatingBridge) + if ok { + return validator.ValidateConfig() + } + return nil + } +} + +func (br *Bridge) getProfile(userID id.UserID, roomID id.RoomID) *event.MemberEventContent { + ghost := br.Child.GetIGhost(userID) + if ghost == nil { + return nil + } + profilefulGhost, ok := ghost.(GhostWithProfile) + if ok { + return &event.MemberEventContent{ + Displayname: profilefulGhost.GetDisplayname(), + AvatarURL: profilefulGhost.GetAvatarURL().CUString(), + } + } + return nil +} + +func (br *Bridge) init() { + pib, ok := br.Child.(PreInitableBridge) + if ok { + pib.PreInit() + } + + var err error + + br.MediaConfig.UploadSize = 50 * 1024 * 1024 + + br.ZLog, err = br.Config.Logging.Compile() + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) + os.Exit(12) + } + exzerolog.SetupDefaults(br.ZLog) + + br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} + + err = br.validateConfig() + if err != nil { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/field-unconfigured for more info") + os.Exit(11) + } + + br.ZLog.Info(). + Str("name", br.Name). + Str("version", br.Version). + Str("built_at", br.BuildTime). + Str("go_version", runtime.Version()). + Msg("Initializing bridge") + + br.ZLog.Debug().Msg("Initializing database connection") + dbConfig := br.Config.AppService.Database + if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { + var fixedExampleURI string + if !strings.HasPrefix(dbConfig.URI, "file:") { + fixedExampleURI = fmt.Sprintf("file:%s?_txlock=immediate", dbConfig.URI) + } else if !strings.ContainsRune(dbConfig.URI, '?') { + fixedExampleURI = fmt.Sprintf("%s?_txlock=immediate", dbConfig.URI) + } else { + fixedExampleURI = fmt.Sprintf("%s&_txlock=immediate", dbConfig.URI) + } + br.ZLog.Warn(). + Str("fixed_uri_example", fixedExampleURI). + Msg("Using SQLite without _txlock=immediate is not recommended") + } + br.DB, err = dbutil.NewFromConfig(br.Name, dbConfig, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "main").Logger())) + if err != nil { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to initialize database connection") + if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { + os.Exit(18) + } + os.Exit(14) + } + br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase + br.DB.IgnoreForeignTables = *ignoreForeignTables + + br.ZLog.Debug().Msg("Initializing state store") + br.StateStore = sqlstatestore.NewSQLStateStore(br.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "matrix_state").Logger()), true) + + br.AS, err = appservice.CreateFull(appservice.CreateOpts{ + Registration: br.Config.AppService.GetRegistration(), + HomeserverDomain: br.Config.Homeserver.Domain, + HomeserverURL: br.Config.Homeserver.Address, + HostConfig: appservice.HostConfig{ + Hostname: br.Config.AppService.Hostname, + Port: br.Config.AppService.Port, + }, + StateStore: br.StateStore, + }) + if err != nil { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). + Msg("Failed to initialize appservice") + os.Exit(15) + } + br.AS.Log = *br.ZLog + br.AS.DoublePuppetValue = br.Name + br.AS.GetProfile = br.getProfile + br.Bot = br.AS.BotIntent() + + br.ZLog.Debug().Msg("Initializing Matrix event processor") + br.EventProcessor = appservice.NewEventProcessor(br.AS) + if !br.Config.AppService.AsyncTransactions { + br.EventProcessor.ExecMode = appservice.Sync + } + br.ZLog.Debug().Msg("Initializing Matrix event handler") + br.MatrixHandler = NewMatrixHandler(br) + + br.Crypto = NewCryptoHelper(br) + + hsURL := br.Config.Homeserver.Address + if br.Config.Homeserver.PublicAddress != "" { + hsURL = br.Config.Homeserver.PublicAddress + } + br.PublicHSAddress, err = url.Parse(hsURL) + if err != nil { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). + Str("input", hsURL). + Msg("Failed to parse public homeserver URL") + os.Exit(15) + } + + br.Child.Init() +} + +type zerologPQError pq.Error + +func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { + maybeStr := func(field, value string) { + if value != "" { + evt.Str(field, value) + } + } + maybeStr("severity", zpe.Severity) + if name := zpe.Code.Name(); name != "" { + evt.Str("code", name) + } else if zpe.Code != "" { + evt.Str("code", string(zpe.Code)) + } + //maybeStr("message", zpe.Message) + maybeStr("detail", zpe.Detail) + maybeStr("hint", zpe.Hint) + maybeStr("position", zpe.Position) + maybeStr("internal_position", zpe.InternalPosition) + maybeStr("internal_query", zpe.InternalQuery) + maybeStr("where", zpe.Where) + maybeStr("schema", zpe.Schema) + maybeStr("table", zpe.Table) + maybeStr("column", zpe.Column) + maybeStr("data_type_name", zpe.DataTypeName) + maybeStr("constraint", zpe.Constraint) + maybeStr("file", zpe.File) + maybeStr("line", zpe.Line) + maybeStr("routine", zpe.Routine) +} + +func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { + logEvt := br.ZLog.WithLevel(zerolog.FatalLevel). + Err(err). + Str("db_section", name) + var errWithLine *dbutil.PQErrorWithLine + if errors.As(err, &errWithLine) { + logEvt.Str("sql_line", errWithLine.Line) + } + var pqe *pq.Error + if errors.As(err, &pqe) { + logEvt.Object("pq_error", (*zerologPQError)(pqe)) + } + logEvt.Msg("Failed to initialize database") + if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { + os.Exit(18) + } else if errors.Is(err, dbutil.ErrForeignTables) { + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") + } else if errors.Is(err, dbutil.ErrNotOwned) { + br.ZLog.Info().Msg("Sharing the same database with different programs is not supported") + } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { + br.ZLog.Info().Msg("Downgrading the bridge is not supported") + } + os.Exit(15) +} + +func (br *Bridge) WaitWebsocketConnected() { + if br.wsStartupWait != nil { + br.wsStartupWait.Wait() + } +} + +func (br *Bridge) start() { + br.ZLog.Debug().Msg("Running database upgrades") + err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO())) + if err != nil { + br.LogDBUpgradeErrorAndExit("main", err) + } else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil { + br.LogDBUpgradeErrorAndExit("matrix_state", err) + } + + if br.Config.Homeserver.Websocket || len(br.Config.Homeserver.WSProxy) > 0 { + br.Websocket = true + br.ZLog.Debug().Msg("Starting application service websocket") + var wg sync.WaitGroup + wg.Add(1) + br.wsStartupWait = &wg + br.wsShortCircuitReconnectBackoff = make(chan struct{}) + go br.startWebsocket(&wg) + } else if br.AS.Host.IsConfigured() { + br.ZLog.Debug().Msg("Starting application service HTTP server") + go br.AS.Start() + } else { + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("Neither appservice HTTP listener nor websocket is enabled") + os.Exit(23) + } + br.ZLog.Debug().Msg("Checking connection to homeserver") + + ctx := br.ZLog.WithContext(context.Background()) + br.ensureConnection(ctx) + go br.fetchMediaConfig(ctx) + + if br.Crypto != nil { + err = br.Crypto.Init(ctx) + if err != nil { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption") + os.Exit(19) + } + } + + br.ZLog.Debug().Msg("Starting event processor") + br.EventProcessor.Start(ctx) + + go br.UpdateBotProfile(ctx) + if br.Crypto != nil { + go br.Crypto.Start() + } + + br.Child.Start() + br.WaitWebsocketConnected() + br.AS.Ready = true + + if br.Config.Bridge.GetResendBridgeInfo() { + go br.ResendBridgeInfo() + } + if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { + br.wsStopPinger = make(chan struct{}, 1) + go br.websocketServerPinger() + } +} + +func (br *Bridge) ResendBridgeInfo() { + if !br.SaveConfig { + br.ZLog.Warn().Msg("Not setting resend_bridge_info to false in config due to --no-update flag") + } else { + _, _, err := configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(func(helper configupgrade.Helper) { + helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") + })) + if err != nil { + br.ZLog.Err(err).Msg("Failed to save config after setting resend_bridge_info to false") + } + } + br.ZLog.Info().Msg("Re-sending bridge info state event to all portals") + for _, portal := range br.Child.GetAllIPortals() { + portal.UpdateBridgeInfo(context.TODO()) + } + br.ZLog.Info().Msg("Finished re-sending bridge info state events") +} + +func sendStopSignal(ch chan struct{}) { + if ch != nil { + select { + case ch <- struct{}{}: + default: + } + } +} + +func (br *Bridge) stop() { + br.Stopping = true + if br.Crypto != nil { + br.Crypto.Stop() + } + waitForWS := false + if br.AS.StopWebsocket != nil { + br.ZLog.Debug().Msg("Stopping application service websocket") + br.AS.StopWebsocket(appservice.ErrWebsocketManualStop) + waitForWS = true + } + br.AS.Stop() + sendStopSignal(br.wsStopPinger) + sendStopSignal(br.wsShortCircuitReconnectBackoff) + br.EventProcessor.Stop() + br.Child.Stop() + err := br.DB.Close() + if err != nil { + br.ZLog.Warn().Err(err).Msg("Error closing database") + } + if waitForWS { + select { + case <-br.wsStopped: + case <-time.After(4 * time.Second): + br.ZLog.Warn().Msg("Timed out waiting for websocket to close") + } + } +} + +func (br *Bridge) ManualStop(exitCode int) { + if br.manualStop != nil { + br.manualStop <- exitCode + } else { + os.Exit(exitCode) + } +} + +type VersionJSONOutput struct { + Name string + URL string + + Version string + IsRelease bool + Commit string + FormattedVersion string + BuildTime string + + OS string + Arch string + + Mautrix struct { + Version string + Commit string + } +} + +func (br *Bridge) Main() { + flag.SetHelpTitles( + fmt.Sprintf("%s - %s", br.Name, br.Description), + fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) + err := flag.Parse() + br.ConfigPath = *configPath + br.RegistrationPath = *registrationPath + br.SaveConfig = !*dontSaveConfig + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + flag.PrintHelp() + os.Exit(1) + } else if *wantHelp { + flag.PrintHelp() + os.Exit(0) + } else if *version { + fmt.Println(br.VersionDesc) + return + } else if *versionJSON { + output := VersionJSONOutput{ + URL: br.URL, + Name: br.Name, + + Version: br.baseVersion, + IsRelease: br.Version == br.baseVersion, + Commit: br.commit, + FormattedVersion: br.Version, + BuildTime: br.BuildTime, + + OS: runtime.GOOS, + Arch: runtime.GOARCH, + } + output.Mautrix.Commit = mautrix.Commit + output.Mautrix.Version = mautrix.Version + _ = json.NewEncoder(os.Stdout).Encode(output) + return + } else if flagHandler, ok := br.Child.(FlagHandlingBridge); ok && flagHandler.HandleFlags() { + return + } + + br.loadConfig() + + if *generateRegistration { + br.GenerateRegistration() + return + } + + br.manualStop = make(chan int, 1) + br.init() + br.ZLog.Info().Msg("Bridge initialization complete, starting...") + br.start() + br.ZLog.Info().Msg("Bridge started!") + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + var exitCode int + select { + case <-c: + br.ZLog.Info().Msg("Interrupt received, stopping...") + case exitCode = <-br.manualStop: + br.ZLog.Info().Int("exit_code", exitCode).Msg("Manual stop requested") + } + + br.stop() + br.ZLog.Info().Msg("Bridge stopped.") + os.Exit(exitCode) +} diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go new file mode 100644 index 00000000..dfb6b7e5 --- /dev/null +++ b/bridge/bridgeconfig/config.go @@ -0,0 +1,337 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/rs/zerolog" + up "go.mau.fi/util/configupgrade" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/random" + "go.mau.fi/zeroconfig" + "gopkg.in/yaml.v3" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/id" +) + +type HomeserverSoftware string + +const ( + SoftwareStandard HomeserverSoftware = "standard" + SoftwareAsmux HomeserverSoftware = "asmux" + SoftwareHungry HomeserverSoftware = "hungry" +) + +var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{ + SoftwareStandard: true, + SoftwareAsmux: true, + SoftwareHungry: true, +} + +type HomeserverConfig struct { + Address string `yaml:"address"` + Domain string `yaml:"domain"` + AsyncMedia bool `yaml:"async_media"` + + PublicAddress string `yaml:"public_address,omitempty"` + + Software HomeserverSoftware `yaml:"software"` + + StatusEndpoint string `yaml:"status_endpoint"` + MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"` + + Websocket bool `yaml:"websocket"` + WSProxy string `yaml:"websocket_proxy"` + WSPingInterval int `yaml:"ping_interval_seconds"` +} + +type AppserviceConfig struct { + Address string `yaml:"address"` + Hostname string `yaml:"hostname"` + Port uint16 `yaml:"port"` + + Database dbutil.Config `yaml:"database"` + + ID string `yaml:"id"` + Bot BotUserConfig `yaml:"bot"` + + ASToken string `yaml:"as_token"` + HSToken string `yaml:"hs_token"` + + EphemeralEvents bool `yaml:"ephemeral_events"` + AsyncTransactions bool `yaml:"async_transactions"` +} + +func (config *BaseConfig) MakeUserIDRegex(matcher string) *regexp.Regexp { + usernamePlaceholder := strings.ToLower(random.String(16)) + usernameTemplate := fmt.Sprintf("@%s:%s", + config.Bridge.FormatUsername(usernamePlaceholder), + config.Homeserver.Domain) + usernameTemplate = regexp.QuoteMeta(usernameTemplate) + usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1) + usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate) + return regexp.MustCompile(usernameTemplate) +} + +// GenerateRegistration generates a registration file for the homeserver. +func (config *BaseConfig) GenerateRegistration() *appservice.Registration { + registration := appservice.CreateRegistration() + config.AppService.HSToken = registration.ServerToken + config.AppService.ASToken = registration.AppToken + config.AppService.copyToRegistration(registration) + + registration.SenderLocalpart = random.String(32) + botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", + regexp.QuoteMeta(config.AppService.Bot.Username), + regexp.QuoteMeta(config.Homeserver.Domain))) + registration.Namespaces.UserIDs.Register(botRegex, true) + registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true) + + return registration +} + +func (config *BaseConfig) MakeAppService() *appservice.AppService { + as := appservice.Create() + as.HomeserverDomain = config.Homeserver.Domain + _ = as.SetHomeserverURL(config.Homeserver.Address) + as.Host.Hostname = config.AppService.Hostname + as.Host.Port = config.AppService.Port + as.Registration = config.AppService.GetRegistration() + return as +} + +// GetRegistration copies the data from the bridge config into an *appservice.Registration struct. +// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver. +func (asc *AppserviceConfig) GetRegistration() *appservice.Registration { + reg := &appservice.Registration{} + asc.copyToRegistration(reg) + reg.SenderLocalpart = asc.Bot.Username + reg.ServerToken = asc.HSToken + reg.AppToken = asc.ASToken + return reg +} + +func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) { + registration.ID = asc.ID + registration.URL = asc.Address + falseVal := false + registration.RateLimited = &falseVal + registration.EphemeralEvents = asc.EphemeralEvents + registration.SoruEphemeralEvents = asc.EphemeralEvents +} + +type BotUserConfig struct { + Username string `yaml:"username"` + Displayname string `yaml:"displayname"` + Avatar string `yaml:"avatar"` + + ParsedAvatar id.ContentURI `yaml:"-"` +} + +type serializableBUC BotUserConfig + +func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + var sbuc serializableBUC + err := unmarshal(&sbuc) + if err != nil { + return err + } + *buc = (BotUserConfig)(sbuc) + if buc.Avatar != "" && buc.Avatar != "remove" { + buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar) + if err != nil { + return fmt.Errorf("%w in bot avatar", err) + } + } + return nil +} + +type BridgeConfig interface { + FormatUsername(username string) string + GetEncryptionConfig() EncryptionConfig + GetCommandPrefix() string + GetManagementRoomTexts() ManagementRoomTexts + GetDoublePuppetConfig() DoublePuppetConfig + GetResendBridgeInfo() bool + EnableMessageStatusEvents() bool + EnableMessageErrorNotices() bool + Validate() error +} + +type DoublePuppetConfig struct { + ServerMap map[string]string `yaml:"double_puppet_server_map"` + AllowDiscovery bool `yaml:"double_puppet_allow_discovery"` + SharedSecretMap map[string]string `yaml:"login_shared_secret_map"` +} + +type EncryptionConfig struct { + Allow bool `yaml:"allow"` + Default bool `yaml:"default"` + Require bool `yaml:"require"` + Appservice bool `yaml:"appservice"` + + PlaintextMentions bool `yaml:"plaintext_mentions"` + + DeleteKeys struct { + DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"` + DontStoreOutbound bool `yaml:"dont_store_outbound"` + RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"` + DeleteFullyUsedOnDecrypt bool `yaml:"delete_fully_used_on_decrypt"` + DeletePrevOnNewSession bool `yaml:"delete_prev_on_new_session"` + DeleteOnDeviceDelete bool `yaml:"delete_on_device_delete"` + PeriodicallyDeleteExpired bool `yaml:"periodically_delete_expired"` + DeleteOutdatedInbound bool `yaml:"delete_outdated_inbound"` + } `yaml:"delete_keys"` + + VerificationLevels struct { + Receive id.TrustState `yaml:"receive"` + Send id.TrustState `yaml:"send"` + Share id.TrustState `yaml:"share"` + } `yaml:"verification_levels"` + AllowKeySharing bool `yaml:"allow_key_sharing"` + + Rotation struct { + EnableCustom bool `yaml:"enable_custom"` + Milliseconds int64 `yaml:"milliseconds"` + Messages int `yaml:"messages"` + + DisableDeviceChangeKeyRotation bool `yaml:"disable_device_change_key_rotation"` + } `yaml:"rotation"` +} + +type ManagementRoomTexts struct { + Welcome string `yaml:"welcome"` + WelcomeConnected string `yaml:"welcome_connected"` + WelcomeUnconnected string `yaml:"welcome_unconnected"` + AdditionalHelp string `yaml:"additional_help"` +} + +type BaseConfig struct { + Homeserver HomeserverConfig `yaml:"homeserver"` + AppService AppserviceConfig `yaml:"appservice"` + Bridge BridgeConfig `yaml:"-"` + Logging zeroconfig.Config `yaml:"logging"` +} + +func doUpgrade(helper up.Helper) { + helper.Copy(up.Str, "homeserver", "address") + helper.Copy(up.Str, "homeserver", "domain") + if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" { + helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software") + } else { + helper.Copy(up.Str, "homeserver", "software") + } + helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") + helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") + helper.Copy(up.Bool, "homeserver", "async_media") + helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") + helper.Copy(up.Bool, "homeserver", "websocket") + helper.Copy(up.Int, "homeserver", "ping_interval_seconds") + + helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "hostname") + helper.Copy(up.Int|up.Null, "appservice", "port") + if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { + helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type") + } else { + helper.Copy(up.Str, "appservice", "database", "type") + } + helper.Copy(up.Str, "appservice", "database", "uri") + helper.Copy(up.Int, "appservice", "database", "max_open_conns") + helper.Copy(up.Int, "appservice", "database", "max_idle_conns") + helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time") + helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime") + helper.Copy(up.Str, "appservice", "id") + helper.Copy(up.Str, "appservice", "bot", "username") + helper.Copy(up.Str, "appservice", "bot", "displayname") + helper.Copy(up.Str, "appservice", "bot", "avatar") + helper.Copy(up.Bool, "appservice", "ephemeral_events") + helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Str, "appservice", "as_token") + helper.Copy(up.Str, "appservice", "hs_token") + + if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log config") + migrateLegacyLogConfig(helper) + } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log config is not currently supported") + // TODO implement? + //migratePythonLogConfig(helper) + } else { + helper.Copy(up.Map, "logging") + } +} + +type legacyLogConfig struct { + Directory string `yaml:"directory"` + FileNameFormat string `yaml:"file_name_format"` + FileDateFormat string `yaml:"file_date_format"` + FileMode uint32 `yaml:"file_mode"` + TimestampFormat string `yaml:"timestamp_format"` + RawPrintLevel string `yaml:"print_level"` + JSONStdout bool `yaml:"print_json"` + JSONFile bool `yaml:"file_json"` +} + +func migrateLegacyLogConfig(helper up.Helper) { + var llc legacyLogConfig + var newConfig zeroconfig.Config + err := helper.GetBaseNode("logging").Decode(&newConfig) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Base config is corrupted: failed to decode example log config:", err) + return + } else if len(newConfig.Writers) != 2 || newConfig.Writers[0].Type != "stdout" || newConfig.Writers[1].Type != "file" { + _, _ = fmt.Fprintln(os.Stderr, "Base log config is not in expected format") + return + } + err = helper.GetNode("logging").Decode(&llc) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to decode legacy log config:", err) + return + } + if llc.RawPrintLevel != "" { + level, err := zerolog.ParseLevel(llc.RawPrintLevel) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse minimum stdout log level:", err) + } else { + newConfig.Writers[0].MinLevel = &level + } + } + if llc.Directory != "" && llc.FileNameFormat != "" { + if llc.FileNameFormat == "{{.Date}}-{{.Index}}.log" { + llc.FileNameFormat = "bridge.log" + } else { + llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Date}}", "") + llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Index}}", "") + } + newConfig.Writers[1].Filename = filepath.Join(llc.Directory, llc.FileNameFormat) + } else if llc.FileNameFormat == "" { + newConfig.Writers = newConfig.Writers[0:1] + } + if llc.JSONStdout { + newConfig.Writers[0].TimeFormat = "" + newConfig.Writers[0].Format = "json" + } else if llc.TimestampFormat != "" { + newConfig.Writers[0].TimeFormat = llc.TimestampFormat + } + var updatedConfig yaml.Node + err = updatedConfig.Encode(&newConfig) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to encode migrated log config:", err) + return + } + *helper.GetBaseNode("logging").Node = updatedConfig +} + +// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks. +var Upgrader = up.SimpleUpgrader(doUpgrade) diff --git a/bridge/bridgeconfig/permissions.go b/bridge/bridgeconfig/permissions.go new file mode 100644 index 00000000..198e140e --- /dev/null +++ b/bridge/bridgeconfig/permissions.go @@ -0,0 +1,71 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "strconv" + "strings" + + "maunium.net/go/mautrix/id" +) + +type PermissionConfig map[string]PermissionLevel + +type PermissionLevel int + +const ( + PermissionLevelBlock PermissionLevel = 0 + PermissionLevelRelay PermissionLevel = 5 + PermissionLevelUser PermissionLevel = 10 + PermissionLevelAdmin PermissionLevel = 100 +) + +var namesToLevels = map[string]PermissionLevel{ + "block": PermissionLevelBlock, + "relay": PermissionLevelRelay, + "user": PermissionLevelUser, + "admin": PermissionLevelAdmin, +} + +func RegisterPermissionLevel(name string, level PermissionLevel) { + namesToLevels[name] = level +} + +func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + rawPC := make(map[string]string) + err := unmarshal(&rawPC) + if err != nil { + return err + } + + if *pc == nil { + *pc = make(map[string]PermissionLevel) + } + for key, value := range rawPC { + level, ok := namesToLevels[strings.ToLower(value)] + if ok { + (*pc)[key] = level + } else if val, err := strconv.Atoi(value); err == nil { + (*pc)[key] = PermissionLevel(val) + } else { + (*pc)[key] = PermissionLevelBlock + } + } + return nil +} + +func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel { + if level, ok := pc[string(userID)]; ok { + return level + } else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok { + return level + } else if level, ok = pc["*"]; ok { + return level + } else { + return PermissionLevelBlock + } +} diff --git a/bridge/bridgestate.go b/bridge/bridgestate.go new file mode 100644 index 00000000..f9c3a3c6 --- /dev/null +++ b/bridge/bridgestate.go @@ -0,0 +1,156 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridge + +import ( + "context" + "runtime/debug" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/status" +) + +func (br *Bridge) SendBridgeState(ctx context.Context, state *status.BridgeState) error { + if br.Websocket { + // FIXME this doesn't account for multiple users + br.latestState = state + + return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + Command: "bridge_status", + Data: state, + }) + } else if br.Config.Homeserver.StatusEndpoint != "" { + return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) + } else { + return nil + } +} + +func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { + if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { + return + } + + for { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if err := br.SendBridgeState(ctx, &state); err != nil { + br.ZLog.Warn().Err(err).Msg("Failed to update global bridge state") + cancel() + time.Sleep(5 * time.Second) + continue + } else { + br.ZLog.Debug().Interface("bridge_state", state).Msg("Sent new global bridge state") + cancel() + break + } + } +} + +type BridgeStateQueue struct { + prev *status.BridgeState + ch chan status.BridgeState + bridge *Bridge + user status.BridgeStateFiller +} + +func (br *Bridge) NewBridgeStateQueue(user status.BridgeStateFiller) *BridgeStateQueue { + if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { + return nil + } + bsq := &BridgeStateQueue{ + ch: make(chan status.BridgeState, 10), + bridge: br, + user: user, + } + go bsq.loop() + return bsq +} + +func (bsq *BridgeStateQueue) loop() { + defer func() { + err := recover() + if err != nil { + bsq.bridge.ZLog.Error(). + Str(zerolog.ErrorStackFieldName, string(debug.Stack())). + Interface(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() + for state := range bsq.ch { + bsq.immediateSendBridgeState(state) + } +} + +func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { + retryIn := 2 + for { + if bsq.prev != nil && bsq.prev.ShouldDeduplicate(&state) { + bsq.bridge.ZLog.Debug(). + Str("state_event", string(state.StateEvent)). + Msg("Not sending bridge state as it's a duplicate") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + err := bsq.bridge.SendBridgeState(ctx, &state) + cancel() + + if err != nil { + bsq.bridge.ZLog.Warn().Err(err). + Int("retry_in_seconds", retryIn). + Msg("Failed to update bridge state") + time.Sleep(time.Duration(retryIn) * time.Second) + retryIn *= 2 + if retryIn > 64 { + retryIn = 64 + } + } else { + bsq.prev = &state + bsq.bridge.ZLog.Debug(). + Interface("bridge_state", state). + Msg("Sent new bridge state") + return + } + } +} + +func (bsq *BridgeStateQueue) Send(state status.BridgeState) { + if bsq == nil { + return + } + + state = state.Fill(bsq.user) + + if len(bsq.ch) >= 8 { + bsq.bridge.ZLog.Warn().Msg("Bridge state queue is nearly full, discarding an item") + select { + case <-bsq.ch: + default: + } + } + select { + case bsq.ch <- state: + default: + bsq.bridge.ZLog.Error().Msg("Bridge state queue is full, dropped new state") + } +} + +func (bsq *BridgeStateQueue) GetPrev() status.BridgeState { + if bsq != nil && bsq.prev != nil { + return *bsq.prev + } + return status.BridgeState{} +} + +func (bsq *BridgeStateQueue) SetPrev(prev status.BridgeState) { + if bsq != nil { + bsq.prev = &prev + } +} diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go new file mode 100644 index 00000000..ff3340e3 --- /dev/null +++ b/bridge/commands/admin.go @@ -0,0 +1,77 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strconv" + + "maunium.net/go/mautrix/id" +) + +var CommandDiscardMegolmSession = &FullHandler{ + Func: func(ce *Event) { + if ce.Bridge.Crypto == nil { + ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") + } else { + ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID) + ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.") + } + }, + Name: "discard-megolm-session", + Aliases: []string{"discard-session"}, + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Discard the Megolm session in the room", + }, + RequiresAdmin: true, +} + +func fnSetPowerLevel(ce *Event) { + var level int + var userID id.UserID + var err error + if len(ce.Args) == 1 { + level, err = strconv.Atoi(ce.Args[0]) + if err != nil { + ce.Reply("Invalid power level \"%s\"", ce.Args[0]) + return + } + userID = ce.User.GetMXID() + } else if len(ce.Args) == 2 { + userID = id.UserID(ce.Args[0]) + _, _, err := userID.Parse() + if err != nil { + ce.Reply("Invalid user ID \"%s\"", ce.Args[0]) + return + } + level, err = strconv.Atoi(ce.Args[1]) + if err != nil { + ce.Reply("Invalid power level \"%s\"", ce.Args[1]) + return + } + } else { + ce.Reply("**Usage:** `set-pl [user] `") + return + } + _, err = ce.Portal.MainIntent().SetPowerLevel(ce.Ctx, ce.RoomID, userID, level) + if err != nil { + ce.Reply("Failed to set power levels: %v", err) + } +} + +var CommandSetPowerLevel = &FullHandler{ + Func: fnSetPowerLevel, + Name: "set-pl", + Aliases: []string{"set-power-level"}, + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Change the power level in a portal room.", + Args: "[_user ID_] <_power level_>", + }, + RequiresAdmin: true, + RequiresPortal: true, +} diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go new file mode 100644 index 00000000..3f074951 --- /dev/null +++ b/bridge/commands/doublepuppet.go @@ -0,0 +1,83 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +var CommandLoginMatrix = &FullHandler{ + Func: fnLoginMatrix, + Name: "login-matrix", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Enable double puppeting.", + Args: "<_access token_>", + }, + RequiresLogin: true, +} + +func fnLoginMatrix(ce *Event) { + if len(ce.Args) == 0 { + ce.Reply("**Usage:** `login-matrix `") + return + } + puppet := ce.User.GetIDoublePuppet() + if puppet == nil { + puppet = ce.User.GetIGhost() + if puppet == nil { + ce.Reply("Didn't get a ghost :(") + return + } + } + err := puppet.SwitchCustomMXID(ce.Args[0], ce.User.GetMXID()) + if err != nil { + ce.Reply("Failed to enable double puppeting: %v", err) + } else { + ce.Reply("Successfully switched puppet") + } +} + +var CommandPingMatrix = &FullHandler{ + Func: fnPingMatrix, + Name: "ping-matrix", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Ping the Matrix server with the double puppet.", + }, + RequiresLogin: true, +} + +func fnPingMatrix(ce *Event) { + puppet := ce.User.GetIDoublePuppet() + if puppet == nil || puppet.CustomIntent() == nil { + ce.Reply("You are not logged in with your Matrix account.") + return + } + resp, err := puppet.CustomIntent().Whoami(ce.Ctx) + if err != nil { + ce.Reply("Failed to validate Matrix login: %v", err) + } else { + ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) + } +} + +var CommandLogoutMatrix = &FullHandler{ + Func: fnLogoutMatrix, + Name: "logout-matrix", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Disable double puppeting.", + }, + RequiresLogin: true, +} + +func fnLogoutMatrix(ce *Event) { + puppet := ce.User.GetIDoublePuppet() + if puppet == nil || puppet.CustomIntent() == nil { + ce.Reply("You don't have double puppeting enabled.") + return + } + puppet.ClearCustomMXID() + ce.Reply("Successfully disabled double puppeting.") +} diff --git a/bridge/commands/event.go b/bridge/commands/event.go new file mode 100644 index 00000000..42b49b68 --- /dev/null +++ b/bridge/commands/event.go @@ -0,0 +1,95 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +// Event stores all data which might be used to handle commands +type Event struct { + Bot *appservice.IntentAPI + Bridge *bridge.Bridge + Portal bridge.Portal + Processor *Processor + Handler MinimalHandler + RoomID id.RoomID + EventID id.EventID + User bridge.User + Command string + Args []string + RawArgs string + ReplyTo id.EventID + Ctx context.Context + ZLog *zerolog.Logger +} + +// MainIntent returns the intent to use when replying to the command. +// +// It prefers the bridge bot, but falls back to the other user in DMs if the bridge bot is not present. +func (ce *Event) MainIntent() *appservice.IntentAPI { + intent := ce.Bot + if ce.Portal != nil && ce.Portal.IsPrivateChat() && !ce.Portal.IsEncrypted() { + intent = ce.Portal.MainIntent() + } + return intent +} + +// Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. +func (ce *Event) Reply(msg string, args ...interface{}) { + msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.Config.Bridge.GetCommandPrefix()+" ") + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + ce.ReplyAdvanced(msg, true, false) +} + +// ReplyAdvanced sends a reply to command as notice. It allows using HTML and disabling markdown, +// but doesn't have built-in string formatting. +func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { + content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) + content.MsgType = event.MsgNotice + _, err := ce.MainIntent().SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, content) + if err != nil { + ce.ZLog.Error().Err(err).Msgf("Failed to reply to command") + } +} + +// React sends a reaction to the command. +func (ce *Event) React(key string) { + _, err := ce.MainIntent().SendReaction(ce.Ctx, ce.RoomID, ce.EventID, key) + if err != nil { + ce.ZLog.Error().Err(err).Msgf("Failed to react to command") + } +} + +// Redact redacts the command. +func (ce *Event) Redact(req ...mautrix.ReqRedact) { + _, err := ce.MainIntent().RedactEvent(ce.Ctx, ce.RoomID, ce.EventID, req...) + if err != nil { + ce.ZLog.Error().Err(err).Msgf("Failed to redact command") + } +} + +// MarkRead marks the command event as read. +func (ce *Event) MarkRead() { + err := ce.MainIntent().SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) + if err != nil { + ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read") + } +} diff --git a/bridge/commands/handler.go b/bridge/commands/handler.go new file mode 100644 index 00000000..ab6899c0 --- /dev/null +++ b/bridge/commands/handler.go @@ -0,0 +1,100 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "maunium.net/go/mautrix/bridge" + "maunium.net/go/mautrix/bridge/bridgeconfig" + "maunium.net/go/mautrix/event" +) + +type MinimalHandler interface { + Run(*Event) +} + +type MinimalHandlerFunc func(*Event) + +func (mhf MinimalHandlerFunc) Run(ce *Event) { + mhf(ce) +} + +type CommandState struct { + Next MinimalHandler + Action string + Meta interface{} +} + +type CommandingUser interface { + bridge.User + GetCommandState() *CommandState + SetCommandState(*CommandState) +} + +type Handler interface { + MinimalHandler + GetName() string +} + +type AliasedHandler interface { + Handler + GetAliases() []string +} + +type FullHandler struct { + Func func(*Event) + + Name string + Aliases []string + Help HelpMeta + + RequiresAdmin bool + RequiresPortal bool + RequiresLogin bool + + RequiresEventLevel event.Type +} + +func (fh *FullHandler) GetHelp() HelpMeta { + fh.Help.Command = fh.Name + return fh.Help +} + +func (fh *FullHandler) GetName() string { + return fh.Name +} + +func (fh *FullHandler) GetAliases() []string { + return fh.Aliases +} + +func (fh *FullHandler) ShowInHelp(ce *Event) bool { + return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin +} + +func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { + levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) + if err != nil { + ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") + ce.Reply("Failed to get room power levels to see if you're allowed to use that command") + return false + } + return levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(fh.RequiresEventLevel) +} + +func (fh *FullHandler) Run(ce *Event) { + if fh.RequiresAdmin && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin { + ce.Reply("That command is limited to bridge administrators.") + } else if fh.RequiresEventLevel.Type != "" && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin && !fh.userHasRoomPermission(ce) { + ce.Reply("That command requires room admin rights.") + } else if fh.RequiresPortal && ce.Portal == nil { + ce.Reply("That command can only be ran in portal rooms.") + } else if fh.RequiresLogin && !ce.User.IsLoggedIn() { + ce.Reply("That command requires you to be logged in.") + } else { + fh.Func(ce) + } +} diff --git a/bridge/commands/help.go b/bridge/commands/help.go new file mode 100644 index 00000000..f4891555 --- /dev/null +++ b/bridge/commands/help.go @@ -0,0 +1,129 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "fmt" + "sort" + "strings" +) + +type HelpfulHandler interface { + Handler + GetHelp() HelpMeta + ShowInHelp(*Event) bool +} + +type HelpSection struct { + Name string + Order int +} + +var ( + // Deprecated: this should be used as a placeholder that needs to be fixed + HelpSectionUnclassified = HelpSection{"Unclassified", -1} + + HelpSectionGeneral = HelpSection{"General", 0} + HelpSectionAuth = HelpSection{"Authentication", 10} + HelpSectionAdmin = HelpSection{"Administration", 50} +) + +type HelpMeta struct { + Command string + Section HelpSection + Description string + Args string +} + +func (hm *HelpMeta) String() string { + if len(hm.Args) == 0 { + return fmt.Sprintf("**%s** - %s", hm.Command, hm.Description) + } + return fmt.Sprintf("**%s** %s - %s", hm.Command, hm.Args, hm.Description) +} + +type helpSectionList []HelpSection + +func (h helpSectionList) Len() int { + return len(h) +} + +func (h helpSectionList) Less(i, j int) bool { + return h[i].Order < h[j].Order +} + +func (h helpSectionList) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +type helpMetaList []HelpMeta + +func (h helpMetaList) Len() int { + return len(h) +} + +func (h helpMetaList) Less(i, j int) bool { + return h[i].Command < h[j].Command +} + +func (h helpMetaList) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +var _ sort.Interface = (helpSectionList)(nil) +var _ sort.Interface = (helpMetaList)(nil) + +func FormatHelp(ce *Event) string { + sections := make(map[HelpSection]helpMetaList) + for _, handler := range ce.Processor.handlers { + helpfulHandler, ok := handler.(HelpfulHandler) + if !ok || !helpfulHandler.ShowInHelp(ce) { + continue + } + help := helpfulHandler.GetHelp() + if help.Description == "" { + continue + } + sections[help.Section] = append(sections[help.Section], help) + } + + sortedSections := make(helpSectionList, 0, len(sections)) + for section := range sections { + sortedSections = append(sortedSections, section) + } + sort.Sort(sortedSections) + + var output strings.Builder + output.Grow(10240) + + var prefixMsg string + if ce.RoomID == ce.User.GetManagementRoomID() { + prefixMsg = "This is your management room: prefixing commands with `%s` is not required." + } else if ce.Portal != nil { + prefixMsg = "**This is a portal room**: you must always prefix commands with `%s`. Management commands will not be bridged." + } else { + prefixMsg = "This is not your management room: prefixing commands with `%s` is required." + } + _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.Config.Bridge.GetCommandPrefix()) + output.WriteByte('\n') + output.WriteString("Parameters in [square brackets] are optional, while parameters in are required.") + output.WriteByte('\n') + output.WriteByte('\n') + + for _, section := range sortedSections { + output.WriteString("#### ") + output.WriteString(section.Name) + output.WriteByte('\n') + sort.Sort(sections[section]) + for _, command := range sections[section] { + output.WriteString(command.String()) + output.WriteByte('\n') + } + output.WriteByte('\n') + } + return output.String() +} diff --git a/bridge/commands/meta.go b/bridge/commands/meta.go new file mode 100644 index 00000000..615f6a34 --- /dev/null +++ b/bridge/commands/meta.go @@ -0,0 +1,56 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +var CommandHelp = &FullHandler{ + Func: func(ce *Event) { + ce.Reply(FormatHelp(ce)) + }, + Name: "help", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Show this help message.", + }, +} + +var CommandVersion = &FullHandler{ + Func: func(ce *Event) { + ce.Reply("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, ce.Bridge.BuildTime) + }, + Name: "version", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Get the bridge version.", + }, +} + +var CommandCancel = &FullHandler{ + Func: func(ce *Event) { + commandingUser, ok := ce.User.(CommandingUser) + if !ok { + ce.Reply("This bridge does not implement cancelable commands") + return + } + state := commandingUser.GetCommandState() + + if state != nil { + action := state.Action + if action == "" { + action = "Unknown action" + } + commandingUser.SetCommandState(nil) + ce.Reply("%s cancelled.", action) + } else { + ce.Reply("No ongoing command.") + } + }, + Name: "cancel", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Cancel an ongoing action.", + }, +} diff --git a/bridge/commands/processor.go b/bridge/commands/processor.go new file mode 100644 index 00000000..6158a7cd --- /dev/null +++ b/bridge/commands/processor.go @@ -0,0 +1,122 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "runtime/debug" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridge" + "maunium.net/go/mautrix/id" +) + +type Processor struct { + bridge *bridge.Bridge + log *zerolog.Logger + + handlers map[string]Handler + aliases map[string]string +} + +// NewProcessor creates a Processor +func NewProcessor(bridge *bridge.Bridge) *Processor { + proc := &Processor{ + bridge: bridge, + log: bridge.ZLog, + + handlers: make(map[string]Handler), + aliases: make(map[string]string), + } + proc.AddHandlers( + CommandHelp, CommandVersion, CommandCancel, + CommandLoginMatrix, CommandLogoutMatrix, CommandPingMatrix, + CommandDiscardMegolmSession, CommandSetPowerLevel) + return proc +} + +func (proc *Processor) AddHandlers(handlers ...Handler) { + for _, handler := range handlers { + proc.AddHandler(handler) + } +} + +func (proc *Processor) AddHandler(handler Handler) { + proc.handlers[handler.GetName()] = handler + aliased, ok := handler.(AliasedHandler) + if ok { + for _, alias := range aliased.GetAliases() { + proc.aliases[alias] = handler.GetName() + } + } +} + +// Handle handles messages to the bridge +func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridge.User, message string, replyTo id.EventID) { + defer func() { + err := recover() + if err != nil { + zerolog.Ctx(ctx).Error(). + Str(zerolog.ErrorStackFieldName, string(debug.Stack())). + Interface(zerolog.ErrorFieldName, err). + Msg("Panic in Matrix command handler") + } + }() + args := strings.Fields(message) + if len(args) == 0 { + args = []string{"unknown-command"} + } + command := strings.ToLower(args[0]) + rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") + log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() + ctx = log.WithContext(ctx) + ce := &Event{ + Bot: proc.bridge.Bot, + Bridge: proc.bridge, + Portal: proc.bridge.Child.GetIPortal(roomID), + Processor: proc, + RoomID: roomID, + EventID: eventID, + User: user, + Command: command, + Args: args[1:], + RawArgs: rawArgs, + ReplyTo: replyTo, + Ctx: ctx, + ZLog: &log, + } + log.Debug().Msg("Received command") + + realCommand, ok := proc.aliases[ce.Command] + if !ok { + realCommand = ce.Command + } + commandingUser, ok := ce.User.(CommandingUser) + + var handler MinimalHandler + handler, ok = proc.handlers[realCommand] + if !ok { + var state *CommandState + if commandingUser != nil { + state = commandingUser.GetCommandState() + } + if state != nil && state.Next != nil { + ce.Command = "" + ce.RawArgs = message + ce.Args = args + ce.Handler = state.Next + state.Next.Run(ce) + } else { + ce.Reply("Unknown command, use the `help` command for help.") + } + } else { + ce.Handler = handler + handler.Run(ce) + } +} diff --git a/bridge/crypto.go b/bridge/crypto.go new file mode 100644 index 00000000..f0b90056 --- /dev/null +++ b/bridge/crypto.go @@ -0,0 +1,511 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build cgo && !nocrypto + +package bridge + +import ( + "context" + "errors" + "fmt" + "os" + "runtime/debug" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridge/bridgeconfig" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/sqlstatestore" +) + +var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) + +var NoSessionFound = crypto.NoSessionFound +var DuplicateMessageIndex = crypto.DuplicateMessageIndex +var UnknownMessageIndex = olm.UnknownMessageIndex + +type CryptoHelper struct { + bridge *Bridge + client *mautrix.Client + mach *crypto.OlmMachine + store *SQLCryptoStore + log *zerolog.Logger + + lock sync.RWMutex + syncDone sync.WaitGroup + cancelSync func() + + cancelPeriodicDeleteLoop func() +} + +func NewCryptoHelper(bridge *Bridge) Crypto { + if !bridge.Config.Bridge.GetEncryptionConfig().Allow { + bridge.ZLog.Debug().Msg("Bridge built with end-to-bridge encryption, but disabled in config") + return nil + } + log := bridge.ZLog.With().Str("component", "crypto").Logger() + return &CryptoHelper{ + bridge: bridge, + log: &log, + } +} + +func (helper *CryptoHelper) Init(ctx context.Context) error { + if len(helper.bridge.CryptoPickleKey) == 0 { + panic("CryptoPickleKey not set") + } + helper.log.Debug().Msg("Initializing end-to-bridge encryption...") + + helper.store = NewSQLCryptoStore( + helper.bridge.DB, + dbutil.ZeroLogger(helper.bridge.ZLog.With().Str("db_section", "crypto").Logger()), + helper.bridge.AS.BotMXID(), + fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), + helper.bridge.CryptoPickleKey, + ) + + err := helper.store.DB.Upgrade(ctx) + if err != nil { + helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) + } + + var isExistingDevice bool + helper.client, isExistingDevice, err = helper.loginBot(ctx) + if err != nil { + return err + } + + helper.log.Debug(). + Str("device_id", helper.client.DeviceID.String()). + Msg("Logged in as bridge bot") + stateStore := &cryptoStateStore{helper.bridge} + helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, stateStore) + helper.mach.AllowKeyShare = helper.allowKeyShare + + encryptionConfig := helper.bridge.Config.Bridge.GetEncryptionConfig() + helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive + helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions + + helper.mach.DeleteOutboundKeysOnAck = encryptionConfig.DeleteKeys.DeleteOutboundOnAck + helper.mach.DontStoreOutboundKeys = encryptionConfig.DeleteKeys.DontStoreOutbound + helper.mach.RatchetKeysOnDecrypt = encryptionConfig.DeleteKeys.RatchetOnDecrypt + helper.mach.DeleteFullyUsedKeysOnDecrypt = encryptionConfig.DeleteKeys.DeleteFullyUsedOnDecrypt + helper.mach.DeletePreviousKeysOnReceive = encryptionConfig.DeleteKeys.DeletePrevOnNewSession + helper.mach.DeleteKeysOnDeviceDelete = encryptionConfig.DeleteKeys.DeleteOnDeviceDelete + helper.mach.DisableDeviceChangeKeyRotation = encryptionConfig.Rotation.DisableDeviceChangeKeyRotation + if encryptionConfig.DeleteKeys.PeriodicallyDeleteExpired { + ctx, cancel := context.WithCancel(context.Background()) + helper.cancelPeriodicDeleteLoop = cancel + go helper.mach.ExpiredKeyDeleteLoop(ctx) + } + + if encryptionConfig.DeleteKeys.DeleteOutdatedInbound { + deleted, err := helper.store.RedactOutdatedGroupSessions(ctx) + if err != nil { + return err + } + if len(deleted) > 0 { + helper.log.Debug().Int("deleted", len(deleted)).Msg("Deleted inbound keys which lacked expiration metadata") + } + } + + helper.client.Syncer = &cryptoSyncer{helper.mach} + helper.client.Store = helper.store + + err = helper.mach.Load(ctx) + if err != nil { + return err + } + if isExistingDevice { + helper.verifyKeysAreOnServer(ctx) + } + + go helper.resyncEncryptionInfo(context.TODO()) + + return nil +} + +func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { + log := helper.log.With().Str("action", "resync encryption event").Logger() + rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + if err != nil { + log.Err(err).Msg("Failed to query rooms for resync") + return + } + roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + if err != nil { + log.Err(err).Msg("Failed to scan rooms for resync") + return + } + if len(roomIDs) > 0 { + log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") + for _, roomID := range roomIDs { + var evt event.EncryptionEventContent + err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) + if err != nil { + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") + _, err = helper.bridge.DB.Exec(ctx, ` + UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' + `, roomID) + if err != nil { + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") + } + } else { + maxAge := evt.RotationPeriodMillis + if maxAge <= 0 { + maxAge = (7 * 24 * time.Hour).Milliseconds() + } + maxMessages := evt.RotationPeriodMessages + if maxMessages <= 0 { + maxMessages = 100 + } + log.Debug(). + Str("room_id", roomID.String()). + Int64("max_age_ms", maxAge). + Int("max_messages", maxMessages). + Interface("content", &evt). + Msg("Resynced encryption event") + _, err = helper.bridge.DB.Exec(ctx, ` + UPDATE crypto_megolm_inbound_session + SET max_age=$1, max_messages=$2 + WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL + `, maxAge, maxMessages, roomID) + if err != nil { + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") + } else { + log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") + } + } + } + } +} + +func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device, info event.RequestedKeyInfo) *crypto.KeyShareRejection { + cfg := helper.bridge.Config.Bridge.GetEncryptionConfig() + if !cfg.AllowKeySharing { + return &crypto.KeyShareRejectNoResponse + } else if device.Trust == id.TrustStateBlacklisted { + return &crypto.KeyShareRejectBlacklisted + } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { + portal := helper.bridge.Child.GetIPortal(info.RoomID) + if portal == nil { + zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: room is not a portal") + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"} + } + user := helper.bridge.Child.GetIUser(device.UserID, true) + // FIXME reimplement IsInPortal + if user.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin /*&& !user.IsInPortal(portal.Key)*/ { + zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: user is not in portal") + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"} + } + zerolog.Ctx(ctx).Debug().Msg("Accepting key request") + return nil + } else { + return &crypto.KeyShareRejectUnverified + } +} + +func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) { + deviceID, err := helper.store.FindDeviceID(ctx) + if err != nil { + return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) + } else if len(deviceID) > 0 { + helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") + } + // Create a new client instance with the default AS settings (including as_token), + // the Login call will then override the access token in the client. + client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) + flows, err := client.GetLoginFlows(ctx) + if err != nil { + return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) + } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { + return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") + } + resp, err := client.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypeAppservice, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: string(helper.bridge.AS.BotMXID()), + }, + DeviceID: deviceID, + StoreCredentials: true, + + InitialDeviceDisplayName: fmt.Sprintf("%s bridge", helper.bridge.ProtocolName), + }) + if err != nil { + return nil, deviceID != "", fmt.Errorf("failed to log in as bridge bot: %w", err) + } + helper.store.DeviceID = resp.DeviceID + return client, deviceID != "", nil +} + +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { + helper.log.Debug().Msg("Making sure keys are still on server") + resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ + DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ + helper.client.UserID: {helper.client.DeviceID}, + }, + }) + if err != nil { + helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to query own keys to make sure device still exists") + os.Exit(33) + } + device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] + if ok && len(device.Keys) > 0 { + return + } + helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") + helper.Reset(ctx, false) +} + +func (helper *CryptoHelper) Start() { + if helper.bridge.Config.Bridge.GetEncryptionConfig().Appservice { + helper.log.Debug().Msg("End-to-bridge encryption is in appservice mode, registering event listeners and not starting syncer") + helper.bridge.AS.Registration.EphemeralEvents = true + helper.mach.AddAppserviceListener(helper.bridge.EventProcessor) + return + } + helper.syncDone.Add(1) + defer helper.syncDone.Done() + helper.log.Debug().Msg("Starting syncer for receiving to-device messages") + var ctx context.Context + ctx, helper.cancelSync = context.WithCancel(context.Background()) + err := helper.client.SyncWithContext(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Fatal error syncing") + os.Exit(51) + } else { + helper.log.Info().Msg("Bridge bot to-device syncer stopped without error") + } +} + +func (helper *CryptoHelper) Stop() { + helper.log.Debug().Msg("CryptoHelper.Stop() called, stopping bridge bot sync") + helper.client.StopSync() + if helper.cancelSync != nil { + helper.cancelSync() + } + if helper.cancelPeriodicDeleteLoop != nil { + helper.cancelPeriodicDeleteLoop() + } + helper.syncDone.Wait() +} + +func (helper *CryptoHelper) clearDatabase(ctx context.Context) { + _, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account") + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table") + } + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session") + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table") + } + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session") + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table") + } + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_device") + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_tracked_user") + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_keys") + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures") +} + +func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) { + helper.lock.Lock() + defer helper.lock.Unlock() + helper.log.Info().Msg("Resetting end-to-bridge encryption device") + helper.Stop() + helper.log.Debug().Msg("Crypto syncer stopped, clearing database") + helper.clearDatabase(ctx) + helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") + _, err := helper.client.LogoutAll(ctx) + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to log out all devices") + } + helper.client = nil + helper.store = nil + helper.mach = nil + err = helper.Init(ctx) + if err != nil { + helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption") + os.Exit(50) + } + helper.log.Info().Msg("End-to-bridge encryption successfully reset") + if startAfterReset { + go helper.Start() + } +} + +func (helper *CryptoHelper) Client() *mautrix.Client { + return helper.client +} + +func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return helper.mach.DecryptMegolmEvent(ctx, evt) +} + +func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { + helper.lock.RLock() + defer helper.lock.RUnlock() + var encrypted *event.EncryptedEventContent + encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) + if err != nil { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { + return + } + helper.log.Debug().Err(err). + Str("room_id", roomID.String()). + Msg("Got error while encrypting event for room, sharing group session and trying again...") + var users []id.UserID + users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to get room member list: %w", err) + } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { + err = fmt.Errorf("failed to share group session: %w", err) + } else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { + err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + if encrypted != nil { + content.Parsed = encrypted + content.Raw = nil + } + return +} + +func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + helper.lock.RLock() + defer helper.lock.RUnlock() + return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) +} + +func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { + helper.lock.RLock() + defer helper.lock.RUnlock() + if deviceID == "" { + deviceID = "*" + } + err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) + if err != nil { + helper.log.Warn().Err(err). + Str("user_id", userID.String()). + Str("device_id", deviceID.String()). + Str("session_id", sessionID.String()). + Str("room_id", roomID.String()). + Msg("Failed to send key request") + } else { + helper.log.Debug(). + Str("user_id", userID.String()). + Str("device_id", deviceID.String()). + Str("session_id", sessionID.String()). + Str("room_id", roomID.String()). + Msg("Sent key request") + } +} + +func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) { + helper.lock.RLock() + defer helper.lock.RUnlock() + err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) + if err != nil { + helper.log.Debug().Err(err). + Str("room_id", roomID.String()). + Msg("Error manually removing outbound group session in room") + } +} + +func (helper *CryptoHelper) HandleMemberEvent(ctx context.Context, evt *event.Event) { + helper.lock.RLock() + defer helper.lock.RUnlock() + helper.mach.HandleMemberEvent(ctx, evt) +} + +// ShareKeys uploads the given number of one-time-keys to the server. +func (helper *CryptoHelper) ShareKeys(ctx context.Context) error { + return helper.mach.ShareKeys(ctx, -1) +} + +type cryptoSyncer struct { + *crypto.OlmMachine +} + +func (syncer *cryptoSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + done := make(chan struct{}) + go func() { + defer func() { + if err := recover(); err != nil { + syncer.Log.Error(). + Str("since", since). + Interface("error", err). + Str("stack", string(debug.Stack())). + Msg("Processing sync response panicked") + } + done <- struct{}{} + }() + syncer.Log.Trace().Str("since", since).Msg("Starting sync response handling") + syncer.ProcessSyncResponse(ctx, resp, since) + syncer.Log.Trace().Str("since", since).Msg("Successfully handled sync response") + }() + select { + case <-done: + case <-time.After(30 * time.Second): + syncer.Log.Warn().Str("since", since).Msg("Handling sync response is taking unusually long") + } + return nil +} + +func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { + if errors.Is(err, mautrix.MUnknownToken) { + return 0, err + } + syncer.Log.Error().Err(err).Msg("Error /syncing, waiting 10 seconds") + return 10 * time.Second, nil +} + +func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { + everything := []event.Type{{Type: "*"}} + return &mautrix.Filter{ + Presence: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + Room: mautrix.RoomFilter{ + IncludeLeave: false, + Ephemeral: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + State: mautrix.FilterPart{NotTypes: everything}, + Timeline: mautrix.FilterPart{NotTypes: everything}, + }, + } +} + +type cryptoStateStore struct { + bridge *Bridge +} + +var _ crypto.StateStore = (*cryptoStateStore)(nil) + +func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) { + portal := c.bridge.Child.GetIPortal(id) + if portal != nil { + return portal.IsEncrypted(), nil + } + return c.bridge.StateStore.IsEncrypted(ctx, id) +} + +func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) { + return c.bridge.StateStore.FindSharedRooms(ctx, id) +} + +func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) { + return c.bridge.StateStore.GetEncryptionEvent(ctx, id) +} diff --git a/bridge/cryptostore.go b/bridge/cryptostore.go new file mode 100644 index 00000000..dde48a25 --- /dev/null +++ b/bridge/cryptostore.go @@ -0,0 +1,63 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build cgo && !nocrypto + +package bridge + +import ( + "context" + + "github.com/lib/pq" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/id" +) + +func init() { + crypto.PostgresArrayWrapper = pq.Array +} + +type SQLCryptoStore struct { + *crypto.SQLCryptoStore + UserID id.UserID + GhostIDFormat string +} + +var _ crypto.Store = (*SQLCryptoStore)(nil) + +func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id.UserID, ghostIDFormat, pickleKey string) *SQLCryptoStore { + return &SQLCryptoStore{ + SQLCryptoStore: crypto.NewSQLCryptoStore(db, log, "", "", []byte(pickleKey)), + UserID: userID, + GhostIDFormat: ghostIDFormat, + } +} + +func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { + var rows dbutil.Rows + rows, err = store.DB.Query(ctx, ` + SELECT user_id FROM mx_user_profile + WHERE room_id=$1 + AND (membership='join' OR membership='invite') + AND user_id<>$2 + AND user_id NOT LIKE $3 + `, roomID, store.UserID, store.GhostIDFormat) + if err != nil { + return + } + for rows.Next() { + var userID id.UserID + err = rows.Scan(&userID) + if err != nil { + return members, err + } else { + members = append(members, userID) + } + } + return +} diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go new file mode 100644 index 00000000..265d3d5c --- /dev/null +++ b/bridge/doublepuppet.go @@ -0,0 +1,173 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridge + +import ( + "context" + "crypto/hmac" + "crypto/sha512" + "encoding/hex" + "errors" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/id" +) + +type doublePuppetUtil struct { + br *Bridge + log zerolog.Logger +} + +func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { + _, homeserver, err := mxid.Parse() + if err != nil { + return nil, err + } + homeserverURL, found := dp.br.Config.Bridge.GetDoublePuppetConfig().ServerMap[homeserver] + if !found { + if homeserver == dp.br.AS.HomeserverDomain { + homeserverURL = "" + } else if dp.br.Config.Bridge.GetDoublePuppetConfig().AllowDiscovery { + resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) + if err != nil { + return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) + } + homeserverURL = resp.Homeserver.BaseURL + dp.log.Debug(). + Str("homeserver", homeserver). + Str("url", homeserverURL). + Str("user_id", mxid.String()). + Msg("Discovered URL to enable double puppeting for user") + } else { + return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) + } + } + return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) +} + +func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { + client, err := dp.newClient(ctx, mxid, accessToken) + if err != nil { + return nil, err + } + + ia := dp.br.AS.NewIntentAPI("custom") + ia.Client = client + ia.Localpart, _, _ = mxid.Parse() + ia.UserID = mxid + ia.IsCustomPuppet = true + return ia, nil +} + +func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { + dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") + client, err := dp.newClient(ctx, mxid, "") + if err != nil { + return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) + } + bridgeName := fmt.Sprintf("%s Bridge", dp.br.ProtocolName) + req := mautrix.ReqLogin{ + Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, + DeviceID: id.DeviceID(bridgeName), + InitialDeviceDisplayName: bridgeName, + } + if loginSecret == "appservice" { + client.AccessToken = dp.br.AS.Registration.AppToken + req.Type = mautrix.AuthTypeAppservice + } else { + loginFlows, err := client.GetLoginFlows(ctx) + if err != nil { + return "", fmt.Errorf("failed to get supported login flows: %w", err) + } + mac := hmac.New(sha512.New, []byte(loginSecret)) + mac.Write([]byte(mxid)) + token := hex.EncodeToString(mac.Sum(nil)) + switch { + case loginFlows.HasFlow(mautrix.AuthTypeDevtureSharedSecret): + req.Type = mautrix.AuthTypeDevtureSharedSecret + req.Token = token + case loginFlows.HasFlow(mautrix.AuthTypePassword): + req.Type = mautrix.AuthTypePassword + req.Password = token + default: + return "", fmt.Errorf("no supported auth types for shared secret auth found") + } + } + resp, err := client.Login(ctx, &req) + if err != nil { + return "", err + } + return resp.AccessToken, nil +} + +var ( + ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") + ErrNoAccessToken = errors.New("no access token provided") + ErrNoMXID = errors.New("no mxid provided") +) + +const useConfigASToken = "appservice-config" +const asTokenModePrefix = "as_token:" + +func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { + if len(mxid) == 0 { + err = ErrNoMXID + return + } + _, homeserver, _ := mxid.Parse() + loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] + // Special case appservice: prefix to not login and use it as an as_token directly. + if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { + intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) + if err != nil { + return + } + intent.SetAppServiceUserID = true + if savedAccessToken != useConfigASToken { + var resp *mautrix.RespWhoami + resp, err = intent.Whoami(ctx) + if err == nil && resp.UserID != mxid { + err = ErrMismatchingMXID + } + } + return intent, useConfigASToken, err + } + if savedAccessToken == "" || savedAccessToken == useConfigASToken { + if reloginOnFail && hasSecret { + savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) + } else { + err = ErrNoAccessToken + } + if err != nil { + return + } + } + intent, err = dp.newIntent(ctx, mxid, savedAccessToken) + if err != nil { + return + } + var resp *mautrix.RespWhoami + resp, err = intent.Whoami(ctx) + if err != nil { + if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { + intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) + if err == nil { + newAccessToken = intent.AccessToken + } + } + } else if resp.UserID != mxid { + err = ErrMismatchingMXID + } else { + newAccessToken = savedAccessToken + } + return +} diff --git a/bridge/matrix.go b/bridge/matrix.go new file mode 100644 index 00000000..446a0b0a --- /dev/null +++ b/bridge/matrix.go @@ -0,0 +1,755 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridge + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/bridgeconfig" + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +type CommandProcessor interface { + Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user User, message string, replyTo id.EventID) +} + +type MatrixHandler struct { + bridge *Bridge + as *appservice.AppService + log *zerolog.Logger + + TrackEventDuration func(event.Type) func() +} + +func noop() {} + +func noopTrack(_ event.Type) func() { + return noop +} + +func NewMatrixHandler(br *Bridge) *MatrixHandler { + handler := &MatrixHandler{ + bridge: br, + as: br.AS, + log: br.ZLog, + + TrackEventDuration: noopTrack, + } + for evtType := range status.CheckpointTypes { + br.EventProcessor.On(evtType, handler.sendBridgeCheckpoint) + } + br.EventProcessor.On(event.EventMessage, handler.HandleMessage) + br.EventProcessor.On(event.EventEncrypted, handler.HandleEncrypted) + br.EventProcessor.On(event.EventSticker, handler.HandleMessage) + br.EventProcessor.On(event.EventReaction, handler.HandleReaction) + br.EventProcessor.On(event.EventRedaction, handler.HandleRedaction) + br.EventProcessor.On(event.StateMember, handler.HandleMembership) + br.EventProcessor.On(event.StateRoomName, handler.HandleRoomMetadata) + br.EventProcessor.On(event.StateRoomAvatar, handler.HandleRoomMetadata) + br.EventProcessor.On(event.StateTopic, handler.HandleRoomMetadata) + br.EventProcessor.On(event.StateEncryption, handler.HandleEncryption) + br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt) + br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping) + br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels) + br.EventProcessor.On(event.StateJoinRules, handler.HandleJoinRule) + return handler +} + +func (mx *MatrixHandler) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { + if !evt.Mautrix.CheckpointSent { + go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) + } +} + +func (mx *MatrixHandler) HandleEncryption(ctx context.Context, evt *event.Event) { + defer mx.TrackEventDuration(evt.Type)() + if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 { + return + } + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal != nil && !portal.IsEncrypted() { + mx.log.Debug(). + Str("user_id", evt.Sender.String()). + Str("room_id", evt.RoomID.String()). + Msg("Encryption was enabled in room") + portal.MarkEncrypted() + if portal.IsPrivateChat() { + err := mx.as.BotIntent().EnsureJoined(ctx, evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) + if err != nil { + mx.log.Err(err). + Str("room_id", evt.RoomID.String()). + Msg("Failed to join bot to room after encryption was enabled") + } + } + } +} + +func (mx *MatrixHandler) joinAndCheckMembers(ctx context.Context, evt *event.Event, intent *appservice.IntentAPI) *mautrix.RespJoinedMembers { + log := zerolog.Ctx(ctx) + resp, err := intent.JoinRoomByID(ctx, evt.RoomID) + if err != nil { + log.Warn().Err(err).Msg("Failed to join room with invite") + return nil + } + + members, err := intent.JoinedMembers(ctx, resp.RoomID) + if err != nil { + log.Warn().Err(err).Msg("Failed to get members in room after accepting invite, leaving room") + _, _ = intent.LeaveRoom(ctx, resp.RoomID) + return nil + } + + if len(members.Joined) < 2 { + log.Debug().Msg("Leaving empty room after accepting invite") + _, _ = intent.LeaveRoom(ctx, resp.RoomID) + return nil + } + return members +} + +func (mx *MatrixHandler) sendNoticeWithMarkdown(ctx context.Context, roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { + intent := mx.as.BotIntent() + content := format.RenderMarkdown(message, true, false) + content.MsgType = event.MsgNotice + return intent.SendMessageEvent(ctx, roomID, event.EventMessage, content) +} + +func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) { + intent := mx.as.BotIntent() + + user := mx.bridge.Child.GetIUser(evt.Sender, true) + if user == nil { + return + } + + members := mx.joinAndCheckMembers(ctx, evt, intent) + if members == nil { + return + } + + if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { + _, _ = intent.SendNotice(ctx, evt.RoomID, "You are not whitelisted to use this bridge.\n"+ + "If you're the owner of this bridge, see the bridge.permissions section in your config file.") + _, _ = intent.LeaveRoom(ctx, evt.RoomID) + return + } + + texts := mx.bridge.Config.Bridge.GetManagementRoomTexts() + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.Welcome) + + if len(members.Joined) == 2 && (len(user.GetManagementRoomID()) == 0 || evt.Content.AsMember().IsDirect) { + user.SetManagementRoom(evt.RoomID) + _, _ = intent.SendNotice(ctx, user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") + zerolog.Ctx(ctx).Debug().Msg("Registered room as management room with inviter") + } + + if evt.RoomID == user.GetManagementRoomID() { + if user.IsLoggedIn() { + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeConnected) + } else { + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeUnconnected) + } + + additionalHelp := texts.AdditionalHelp + if len(additionalHelp) > 0 { + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, additionalHelp) + } + } +} + +func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event, inviter User, ghost Ghost) { + log := zerolog.Ctx(ctx) + intent := ghost.DefaultIntent() + + if inviter.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { + log.Debug().Msg("Rejecting invite: inviter is not whitelisted") + _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ + Reason: "You're not whitelisted to use this bridge", + }) + if err != nil { + log.Error().Err(err).Msg("Failed to reject invite") + } + return + } else if !inviter.IsLoggedIn() { + log.Debug().Msg("Rejecting invite: inviter is not logged in") + _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ + Reason: "You're not logged into this bridge", + }) + if err != nil { + log.Error().Err(err).Msg("Failed to reject invite") + } + return + } + + members := mx.joinAndCheckMembers(ctx, evt, intent) + if members == nil { + return + } + var createEvent event.CreateEventContent + if err := intent.StateEvent(ctx, evt.RoomID, event.StateCreate, "", &createEvent); err != nil { + log.Warn().Err(err).Msg("Failed to check m.room.create event in room") + } else if createEvent.Type != "" { + log.Warn().Str("room_type", string(createEvent.Type)).Msg("Non-standard room type, leaving room") + _, err = intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ + Reason: "Unsupported room type", + }) + if err != nil { + log.Error().Err(err).Msg("Failed to leave room") + } + return + } + var hasBridgeBot, hasOtherUsers bool + for mxid, _ := range members.Joined { + if mxid == intent.UserID || mxid == inviter.GetMXID() { + continue + } else if mxid == mx.bridge.Bot.UserID { + hasBridgeBot = true + } else { + hasOtherUsers = true + } + } + if !hasBridgeBot && !hasOtherUsers && evt.Content.AsMember().IsDirect { + mx.bridge.Child.CreatePrivatePortal(evt.RoomID, inviter, ghost) + } else if !hasBridgeBot { + log.Debug().Msg("Leaving multi-user room after accepting invite") + _, _ = intent.SendNotice(ctx, evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") + _, _ = intent.LeaveRoom(ctx, evt.RoomID) + } else { + _, _ = intent.SendNotice(ctx, evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") + } +} + +func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) { + if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { + return + } + defer mx.TrackEventDuration(evt.Type)() + + if mx.bridge.Crypto != nil { + mx.bridge.Crypto.HandleMemberEvent(ctx, evt) + } + + log := mx.log.With(). + Str("sender", evt.Sender.String()). + Str("target", evt.GetStateKey()). + Str("room_id", evt.RoomID.String()). + Logger() + ctx = log.WithContext(ctx) + + content := evt.Content.AsMember() + if content.Membership == event.MembershipInvite && id.UserID(evt.GetStateKey()) == mx.as.BotMXID() { + mx.HandleBotInvite(ctx, evt) + return + } + + if mx.shouldIgnoreEvent(evt) { + return + } + + user := mx.bridge.Child.GetIUser(evt.Sender, true) + if user == nil { + return + } + isSelf := id.UserID(evt.GetStateKey()) == evt.Sender + ghost := mx.bridge.Child.GetIGhost(id.UserID(evt.GetStateKey())) + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil { + if ghost != nil && content.Membership == event.MembershipInvite { + mx.HandleGhostInvite(ctx, evt, user, ghost) + } + return + } else if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { + return + } + bhp, bhpOk := portal.(BanHandlingPortal) + mhp, mhpOk := portal.(MembershipHandlingPortal) + khp, khpOk := portal.(KnockHandlingPortal) + ihp, ihpOk := portal.(InviteHandlingPortal) + if !(mhpOk || bhpOk || khpOk) { + return + } + prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) + } + if ihpOk && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan { + if content.Membership == event.MembershipJoin { + ihp.HandleMatrixAcceptInvite(user, evt) + } + if content.Membership == event.MembershipLeave { + if isSelf { + ihp.HandleMatrixRejectInvite(user, evt) + } else if ghost != nil { + ihp.HandleMatrixRetractInvite(user, ghost, evt) + } + } + } + if bhpOk && ghost != nil { + if content.Membership == event.MembershipBan { + bhp.HandleMatrixBan(user, ghost, evt) + } else if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipBan { + bhp.HandleMatrixUnban(user, ghost, evt) + } + } + if khpOk { + if content.Membership == event.MembershipKnock { + khp.HandleMatrixKnock(user, evt) + } else if prevContent.Membership == event.MembershipKnock { + if content.Membership == event.MembershipInvite && ghost != nil { + khp.HandleMatrixAcceptKnock(user, ghost, evt) + } else if content.Membership == event.MembershipLeave { + if isSelf { + khp.HandleMatrixRetractKnock(user, evt) + } else if ghost != nil { + khp.HandleMatrixRejectKnock(user, ghost, evt) + } + } + } + } + if mhpOk { + if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipJoin { + if isSelf { + mhp.HandleMatrixLeave(user, evt) + } else if ghost != nil { + mhp.HandleMatrixKick(user, ghost, evt) + } + } else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil { + mhp.HandleMatrixInvite(user, ghost, evt) + } + } + // TODO kicking/inviting non-ghost users users +} + +func (mx *MatrixHandler) HandleRoomMetadata(ctx context.Context, evt *event.Event) { + defer mx.TrackEventDuration(evt.Type)() + if mx.shouldIgnoreEvent(evt) { + return + } + + user := mx.bridge.Child.GetIUser(evt.Sender, true) + if user == nil { + return + } + + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil || portal.IsPrivateChat() { + return + } + + metaPortal, ok := portal.(MetaHandlingPortal) + if !ok { + return + } + + metaPortal.HandleMatrixMeta(user, evt) +} + +func (mx *MatrixHandler) shouldIgnoreEvent(evt *event.Event) bool { + if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { + return true + } + user := mx.bridge.Child.GetIUser(evt.Sender, true) + if user == nil || user.GetPermissionLevel() <= 0 { + return true + } else if val, ok := evt.Content.Raw[appservice.DoublePuppetKey]; ok && val == mx.bridge.Name && user.GetIDoublePuppet() != nil { + return true + } + return false +} + +const initialSessionWaitTimeout = 3 * time.Second +const extendedSessionWaitTimeout = 22 * time.Second + +func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.Event, editEvent id.EventID, err error, retryCount int, isFinal bool) id.EventID { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, isFinal, retryCount) + + if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { + statusEvent := &event.BeeperMessageStatusEventContent{ + // TODO: network + RelatesTo: event.RelatesTo{ + Type: event.RelReference, + EventID: evt.ID, + }, + Status: event.MessageStatusRetriable, + Reason: event.MessageStatusUndecryptable, + Error: err.Error(), + Message: errorToHumanMessage(err), + } + if !isFinal { + statusEvent.Status = event.MessageStatusPending + } + _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) + if sendErr != nil { + zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to send message status event") + } + } + if mx.bridge.Config.Bridge.EnableMessageErrorNotices() { + update := event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("\u26a0 Your message was not bridged: %v.", err), + } + if errors.Is(err, errNoCrypto) { + update.Body = "🔒 This bridge has not been configured to support encryption" + } + relatable, ok := evt.Content.Parsed.(event.Relatable) + if editEvent != "" { + update.SetEdit(editEvent) + } else if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { + update.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) + } + resp, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, &update) + if sendErr != nil { + zerolog.Ctx(ctx).Error().Err(sendErr).Msg("Failed to send decryption error notice") + } else if resp != nil { + return resp.EventID + } + } + return "" +} + +var ( + errDeviceNotTrusted = errors.New("your device is not trusted") + errMessageNotEncrypted = errors.New("unencrypted message") + errNoDecryptionKeys = errors.New("the bridge hasn't received the decryption keys") + errNoCrypto = errors.New("this bridge has not been configured to support encryption") +) + +func errorToHumanMessage(err error) string { + var withheld *event.RoomKeyWithheldEventContent + switch { + case errors.Is(err, errDeviceNotTrusted), errors.Is(err, errNoDecryptionKeys): + return err.Error() + case errors.Is(err, UnknownMessageIndex): + return "the keys received by the bridge can't decrypt the message" + case errors.Is(err, DuplicateMessageIndex): + return "your client encrypted multiple messages with the same key" + case errors.As(err, &withheld): + if withheld.Code == event.RoomKeyWithheldBeeperRedacted { + return "your client used an outdated encryption session" + } + return "your client refused to share decryption keys with the bridge" + case errors.Is(err, errMessageNotEncrypted): + return "the message is not encrypted" + default: + return "the bridge failed to decrypt the message" + } +} + +func deviceUnverifiedErrorWithExplanation(trust id.TrustState) error { + var explanation string + switch trust { + case id.TrustStateBlacklisted: + explanation = "device is blacklisted" + case id.TrustStateUnset: + explanation = "unverified" + case id.TrustStateUnknownDevice: + explanation = "device info not found" + case id.TrustStateForwarded: + explanation = "keys were forwarded from an unknown device" + case id.TrustStateCrossSignedUntrusted: + explanation = "cross-signing keys changed after setting up the bridge" + default: + return errDeviceNotTrusted + } + return fmt.Errorf("%w (%s)", errDeviceNotTrusted, explanation) +} + +func copySomeKeys(original, decrypted *event.Event) { + isScheduled, _ := original.Content.Raw["com.beeper.scheduled"].(bool) + _, alreadyExists := decrypted.Content.Raw["com.beeper.scheduled"] + if isScheduled && !alreadyExists { + decrypted.Content.Raw["com.beeper.scheduled"] = true + } +} + +func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID id.EventID, duration time.Duration) { + log := zerolog.Ctx(ctx) + minLevel := mx.bridge.Config.Bridge.GetEncryptionConfig().VerificationLevels.Send + if decrypted.Mautrix.TrustState < minLevel { + logEvt := log.Warn(). + Str("user_id", decrypted.Sender.String()). + Bool("forwarded_keys", decrypted.Mautrix.ForwardedKeys). + Stringer("device_trust", decrypted.Mautrix.TrustState). + Stringer("min_trust", minLevel) + if decrypted.Mautrix.TrustSource != nil { + dev := decrypted.Mautrix.TrustSource + logEvt. + Str("device_id", dev.DeviceID.String()). + Str("device_signing_key", dev.SigningKey.String()) + } else { + logEvt.Str("device_id", "unknown") + } + logEvt.Msg("Dropping event due to insufficient verification level") + err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) + go mx.sendCryptoStatusError(ctx, decrypted, errorEventID, err, retryCount, true) + return + } + copySomeKeys(original, decrypted) + + mx.bridge.SendMessageSuccessCheckpoint(decrypted, status.MsgStepDecrypted, retryCount) + decrypted.Mautrix.CheckpointSent = true + decrypted.Mautrix.DecryptionDuration = duration + decrypted.Mautrix.EventSource |= event.SourceDecrypted + mx.bridge.EventProcessor.Dispatch(ctx, decrypted) + if errorEventID != "" { + _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) + } +} + +func (mx *MatrixHandler) HandleEncrypted(ctx context.Context, evt *event.Event) { + defer mx.TrackEventDuration(evt.Type)() + if mx.shouldIgnoreEvent(evt) { + return + } + content := evt.Content.AsEncrypted() + log := zerolog.Ctx(ctx).With(). + Str("event_id", evt.ID.String()). + Str("session_id", content.SessionID.String()). + Logger() + ctx = log.WithContext(ctx) + if mx.bridge.Crypto == nil { + go mx.sendCryptoStatusError(ctx, evt, "", errNoCrypto, 0, true) + return + } + log.Debug().Msg("Decrypting received event") + + decryptionStart := time.Now() + decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) + decryptionRetryCount := 0 + if errors.Is(err, NoSessionFound) { + decryptionRetryCount = 1 + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0) + if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt) + } else { + go mx.waitLongerForSession(ctx, evt, decryptionStart) + return + } + } + if err != nil { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, true, decryptionRetryCount) + log.Warn().Err(err).Msg("Failed to decrypt event") + go mx.sendCryptoStatusError(ctx, evt, "", err, decryptionRetryCount, true) + return + } + mx.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, "", time.Since(decryptionStart)) +} + +func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { + log := zerolog.Ctx(ctx) + content := evt.Content.AsEncrypted() + log.Debug(). + Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, requesting keys and waiting longer...") + + go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) + + if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { + log.Debug().Msg("Didn't get session, giving up trying to decrypt event") + mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true) + return + } + + log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") + decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) + if err != nil { + log.Error().Err(err).Msg("Failed to decrypt event") + mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true) + return + } + + mx.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) +} + +func (mx *MatrixHandler) HandleMessage(ctx context.Context, evt *event.Event) { + defer mx.TrackEventDuration(evt.Type)() + log := zerolog.Ctx(ctx).With(). + Str("event_id", evt.ID.String()). + Str("room_id", evt.RoomID.String()). + Str("sender", evt.Sender.String()). + Logger() + ctx = log.WithContext(ctx) + if mx.shouldIgnoreEvent(evt) { + return + } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { + log.Warn().Msg("Dropping unencrypted event") + mx.sendCryptoStatusError(ctx, evt, "", errMessageNotEncrypted, 0, true) + return + } + + user := mx.bridge.Child.GetIUser(evt.Sender, true) + if user == nil { + return + } + + content := evt.Content.AsMessage() + content.RemoveReplyFallback() + if user.GetPermissionLevel() >= bridgeconfig.PermissionLevelUser && content.MsgType == event.MsgText { + commandPrefix := mx.bridge.Config.Bridge.GetCommandPrefix() + hasCommandPrefix := strings.HasPrefix(content.Body, commandPrefix) + if hasCommandPrefix { + content.Body = strings.TrimLeft(strings.TrimPrefix(content.Body, commandPrefix), " ") + } + if hasCommandPrefix || evt.RoomID == user.GetManagementRoomID() { + go mx.bridge.CommandProcessor.Handle(ctx, evt.RoomID, evt.ID, user, content.Body, content.RelatesTo.GetReplyTo()) + go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepCommand, 0) + if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { + statusEvent := &event.BeeperMessageStatusEventContent{ + // TODO: network + RelatesTo: event.RelatesTo{ + Type: event.RelReference, + EventID: evt.ID, + }, + Status: event.MessageStatusSuccess, + } + _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) + if sendErr != nil { + log.Warn().Err(sendErr).Msg("Failed to send message status event for command") + } + } + return + } + } + + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal != nil { + portal.ReceiveMatrixEvent(user, evt) + } else { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) + } +} + +func (mx *MatrixHandler) HandleReaction(_ context.Context, evt *event.Event) { + defer mx.TrackEventDuration(evt.Type)() + if mx.shouldIgnoreEvent(evt) { + return + } + + user := mx.bridge.Child.GetIUser(evt.Sender, true) + if user == nil || user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { + return + } + + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal != nil { + portal.ReceiveMatrixEvent(user, evt) + } else { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) + } +} + +func (mx *MatrixHandler) HandleRedaction(_ context.Context, evt *event.Event) { + defer mx.TrackEventDuration(evt.Type)() + if mx.shouldIgnoreEvent(evt) { + return + } + + user := mx.bridge.Child.GetIUser(evt.Sender, true) + if user == nil { + return + } + + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal != nil { + portal.ReceiveMatrixEvent(user, evt) + } else { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) + } +} + +func (mx *MatrixHandler) HandleReceipt(_ context.Context, evt *event.Event) { + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil { + return + } + + rrPortal, ok := portal.(ReadReceiptHandlingPortal) + if !ok { + return + } + + for eventID, receipts := range *evt.Content.AsReceipt() { + for userID, receipt := range receipts[event.ReceiptTypeRead] { + user := mx.bridge.Child.GetIUser(userID, false) + if user == nil { + // Not a bridge user + continue + } + customPuppet := user.GetIDoublePuppet() + if val, ok := receipt.Extra[appservice.DoublePuppetKey].(string); ok && customPuppet != nil && val == mx.bridge.Name { + // Ignore double puppeted read receipts. + mx.log.Debug().Interface("content", evt.Content.Raw).Msg("Ignoring double-puppeted read receipt") + // But do start disappearing messages, because the user read the chat + dp, ok := portal.(DisappearingPortal) + if ok { + dp.ScheduleDisappearing() + } + } else { + rrPortal.HandleMatrixReadReceipt(user, eventID, receipt) + } + } + } +} + +func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) { + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil { + return + } + typingPortal, ok := portal.(TypingPortal) + if !ok { + return + } + typingPortal.HandleMatrixTyping(evt.Content.AsTyping().UserIDs) +} + +func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event) { + if mx.shouldIgnoreEvent(evt) { + return + } + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil { + return + } + powerLevelPortal, ok := portal.(PowerLevelHandlingPortal) + if ok { + user := mx.bridge.Child.GetIUser(evt.Sender, true) + powerLevelPortal.HandleMatrixPowerLevels(user, evt) + } +} + +func (mx *MatrixHandler) HandleJoinRule(_ context.Context, evt *event.Event) { + if mx.shouldIgnoreEvent(evt) { + return + } + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil { + return + } + joinRulePortal, ok := portal.(JoinRuleHandlingPortal) + if ok { + user := mx.bridge.Child.GetIUser(evt.Sender, true) + joinRulePortal.HandleMatrixJoinRule(user, evt) + } +} diff --git a/bridge/messagecheckpoint.go b/bridge/messagecheckpoint.go new file mode 100644 index 00000000..a95d2160 --- /dev/null +++ b/bridge/messagecheckpoint.go @@ -0,0 +1,61 @@ +// Copyright (c) 2021 Sumner Evans +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridge + +import ( + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/event" +) + +func (br *Bridge) SendMessageSuccessCheckpoint(evt *event.Event, step status.MessageCheckpointStep, retryNum int) { + br.SendMessageCheckpoint(evt, step, nil, status.MsgStatusSuccess, retryNum) +} + +func (br *Bridge) SendMessageErrorCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, permanent bool, retryNum int) { + s := status.MsgStatusWillRetry + if permanent { + s = status.MsgStatusPermFailure + } + br.SendMessageCheckpoint(evt, step, err, s, retryNum) +} + +func (br *Bridge) SendMessageCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, s status.MessageCheckpointStatus, retryNum int) { + checkpoint := status.NewMessageCheckpoint(evt, step, s, retryNum) + if err != nil { + checkpoint.Info = err.Error() + } + go br.SendRawMessageCheckpoint(checkpoint) +} + +func (br *Bridge) SendRawMessageCheckpoint(cp *status.MessageCheckpoint) { + err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{cp}) + if err != nil { + br.ZLog.Warn().Err(err).Interface("message_checkpoint", cp).Msg("Error sending message checkpoint") + } else { + br.ZLog.Debug().Interface("message_checkpoint", cp).Msg("Sent message checkpoint") + } +} + +func (br *Bridge) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { + checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} + + if br.Websocket { + return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + Command: "message_checkpoint", + Data: checkpointsJSON, + }) + } + + endpoint := br.Config.Homeserver.MessageSendCheckpointEndpoint + if endpoint == "" { + return nil + } + + return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) +} diff --git a/bridge/no-crypto.go b/bridge/no-crypto.go new file mode 100644 index 00000000..019ab7c1 --- /dev/null +++ b/bridge/no-crypto.go @@ -0,0 +1,26 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build !cgo || nocrypto + +package bridge + +import ( + "errors" +) + +func NewCryptoHelper(bridge *Bridge) Crypto { + if bridge.Config.Bridge.GetEncryptionConfig().Allow { + bridge.ZLog.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") + } else { + bridge.ZLog.Debug().Msg("Bridge built without end-to-bridge encryption") + } + return nil +} + +var NoSessionFound = errors.New("nil") +var UnknownMessageIndex = NoSessionFound +var DuplicateMessageIndex = NoSessionFound diff --git a/bridgev2/status/bridgestate.go b/bridge/status/bridgestate.go similarity index 75% rename from bridgev2/status/bridgestate.go rename to bridge/status/bridgestate.go index 5925dd4f..72e61415 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -12,17 +12,16 @@ import ( "encoding/json" "fmt" "io" - "maps" "net/http" "reflect" "time" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "golang.org/x/exp/maps" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -54,42 +53,12 @@ const ( StateLoggedOut BridgeStateEvent = "LOGGED_OUT" ) -func (e BridgeStateEvent) IsValid() bool { - switch e { - case - StateStarting, - StateUnconfigured, - StateRunning, - StateBridgeUnreachable, - StateConnecting, - StateBackfilling, - StateConnected, - StateTransientDisconnect, - StateBadCredentials, - StateUnknownError, - StateLoggedOut: - return true - default: - return false - } -} - -type BridgeStateUserAction string - -const ( - UserActionOpenNative BridgeStateUserAction = "OPEN_NATIVE" - UserActionRelogin BridgeStateUserAction = "RELOGIN" - UserActionRestart BridgeStateUserAction = "RESTART" -) - type RemoteProfile struct { Phone string `json:"phone,omitempty"` Email string `json:"email,omitempty"` Username string `json:"username,omitempty"` Name string `json:"name,omitempty"` Avatar id.ContentURIString `json:"avatar,omitempty"` - - AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"` } func coalesce[T ~string](a, b T) T { @@ -105,14 +74,11 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { other.Username = coalesce(rp.Username, other.Username) other.Name = coalesce(rp.Name, other.Name) other.Avatar = coalesce(rp.Avatar, other.Avatar) - if rp.AvatarFile != nil { - other.AvatarFile = rp.AvatarFile - } return other } -func (rp *RemoteProfile) IsZero() bool { - return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil) +func (rp *RemoteProfile) IsEmpty() bool { + return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "") } type BridgeState struct { @@ -124,12 +90,10 @@ type BridgeState struct { Error BridgeStateErrorCode `json:"error,omitempty"` Message string `json:"message,omitempty"` - UserAction BridgeStateUserAction `json:"user_action,omitempty"` - - UserID id.UserID `json:"user_id,omitempty"` - RemoteID networkid.UserLoginID `json:"remote_id,omitempty"` - RemoteName string `json:"remote_name,omitempty"` - RemoteProfile RemoteProfile `json:"remote_profile,omitzero"` + UserID id.UserID `json:"user_id,omitempty"` + RemoteID string `json:"remote_id,omitempty"` + RemoteName string `json:"remote_name,omitempty"` + RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -141,15 +105,31 @@ type GlobalBridgeState struct { } type BridgeStateFiller interface { + GetMXID() id.UserID + GetRemoteID() string + GetRemoteName() string +} + +type StandaloneCustomBridgeStateFiller interface { FillBridgeState(BridgeState) BridgeState } -// Deprecated: use BridgeStateFiller instead -type StandaloneCustomBridgeStateFiller = BridgeStateFiller +type CustomBridgeStateFiller interface { + BridgeStateFiller + StandaloneCustomBridgeStateFiller +} -func (pong BridgeState) Fill(user BridgeStateFiller) BridgeState { +func (pong BridgeState) Fill(user any) BridgeState { if user != nil { - pong = user.FillBridgeState(pong) + if std, ok := user.(BridgeStateFiller); ok { + pong.UserID = std.GetMXID() + pong.RemoteID = std.GetRemoteID() + pong.RemoteName = std.GetRemoteName() + } + + if custom, ok := user.(StandaloneCustomBridgeStateFiller); ok { + pong = custom.FillBridgeState(pong) + } } pong.Timestamp = jsontime.UnixNow() @@ -207,9 +187,7 @@ func (pong *BridgeState) SendHTTP(ctx context.Context, url, token string) error func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { return pong != nil && pong.StateEvent == newPong.StateEvent && - pong.RemoteName == newPong.RemoteName && - pong.UserAction == newPong.UserAction && - pong.RemoteProfile == newPong.RemoteProfile && + ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/status/messagecheckpoint.go b/bridge/status/messagecheckpoint.go similarity index 96% rename from bridgev2/status/messagecheckpoint.go rename to bridge/status/messagecheckpoint.go index b3c05f4f..ea859b84 100644 --- a/bridgev2/status/messagecheckpoint.go +++ b/bridge/status/messagecheckpoint.go @@ -169,13 +169,13 @@ type CheckpointsJSON struct { Checkpoints []*MessageCheckpoint `json:"checkpoints"` } -func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpoint string, token string) error { +func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error { var body bytes.Buffer if err := json.NewEncoder(&body).Encode(cj); err != nil { return fmt.Errorf("failed to encode message checkpoint JSON: %w", err) } - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &body) if err != nil { @@ -186,10 +186,7 @@ func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpo req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (checkpoint sender)") req.Header.Set("Content-Type", "application/json") - if cli == nil { - cli = http.DefaultClient - } - resp, err := cli.Do(req) + resp, err := http.DefaultClient.Do(req) if err != nil { return mautrix.HTTPError{ Request: req, diff --git a/bridge/websocket.go b/bridge/websocket.go new file mode 100644 index 00000000..44a3d8d8 --- /dev/null +++ b/bridge/websocket.go @@ -0,0 +1,163 @@ +package bridge + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/appservice" +) + +const defaultReconnectBackoff = 2 * time.Second +const maxReconnectBackoff = 2 * time.Minute +const reconnectBackoffReset = 5 * time.Minute + +func (br *Bridge) startWebsocket(wg *sync.WaitGroup) { + log := br.ZLog.With().Str("action", "appservice websocket").Logger() + var wgOnce sync.Once + onConnect := func() { + wssBr, ok := br.Child.(WebsocketStartingBridge) + if ok { + wssBr.OnWebsocketConnect() + } + if br.latestState != nil { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + br.latestState.Timestamp = jsontime.UnixNow() + err := br.SendBridgeState(ctx, br.latestState) + if err != nil { + log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") + } else { + log.Debug().Any("bridge_state", br.latestState).Msg("Resent bridge state after websocket reconnect") + } + }() + } + wgOnce.Do(wg.Done) + select { + case br.wsStarted <- struct{}{}: + default: + } + } + reconnectBackoff := defaultReconnectBackoff + lastDisconnect := time.Now().UnixNano() + br.wsStopped = make(chan struct{}) + defer func() { + log.Debug().Msg("Appservice websocket loop finished") + close(br.wsStopped) + }() + addr := br.Config.Homeserver.WSProxy + if addr == "" { + addr = br.Config.Homeserver.Address + } + for { + err := br.AS.StartWebsocket(addr, onConnect) + if errors.Is(err, appservice.ErrWebsocketManualStop) { + return + } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { + log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") + br.ManualStop(0) + return + } else if err != nil { + log.Err(err).Msg("Error in appservice websocket") + } + if br.Stopping { + return + } + now := time.Now().UnixNano() + if lastDisconnect+reconnectBackoffReset.Nanoseconds() < now { + reconnectBackoff = defaultReconnectBackoff + } else { + reconnectBackoff *= 2 + if reconnectBackoff > maxReconnectBackoff { + reconnectBackoff = maxReconnectBackoff + } + } + lastDisconnect = now + log.Info(). + Int("backoff_seconds", int(reconnectBackoff.Seconds())). + Msg("Websocket disconnected, reconnecting...") + select { + case <-br.wsShortCircuitReconnectBackoff: + log.Debug().Msg("Reconnect backoff was short-circuited") + case <-time.After(reconnectBackoff): + } + if br.Stopping { + return + } + } +} + +type wsPingData struct { + Timestamp int64 `json:"timestamp"` +} + +func (br *Bridge) PingServer() (start, serverTs, end time.Time) { + if !br.Websocket { + panic(fmt.Errorf("PingServer called without websocket enabled")) + } + if !br.AS.HasWebsocket() { + br.ZLog.Debug().Msg("Received server ping request, but no websocket connected. Trying to short-circuit backoff sleep") + select { + case br.wsShortCircuitReconnectBackoff <- struct{}{}: + default: + br.ZLog.Warn().Msg("Failed to ping websocket: not connected and no backoff?") + return + } + select { + case <-br.wsStarted: + case <-time.After(15 * time.Second): + if !br.AS.HasWebsocket() { + br.ZLog.Warn().Msg("Failed to ping websocket: didn't connect after 15 seconds of waiting") + return + } + } + } + start = time.Now() + var resp wsPingData + br.ZLog.Debug().Msg("Pinging appservice websocket") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{ + Command: "ping", + Data: &wsPingData{Timestamp: start.UnixMilli()}, + }, &resp) + end = time.Now() + if err != nil { + br.ZLog.Warn().Err(err).Dur("duration", end.Sub(start)).Msg("Websocket ping returned error") + br.AS.StopWebsocket(fmt.Errorf("websocket ping returned error in %s: %w", end.Sub(start), err)) + } else { + serverTs = time.Unix(0, resp.Timestamp*int64(time.Millisecond)) + br.ZLog.Debug(). + Dur("duration", end.Sub(start)). + Dur("req_duration", serverTs.Sub(start)). + Dur("resp_duration", end.Sub(serverTs)). + Msg("Websocket ping returned success") + } + return +} + +func (br *Bridge) websocketServerPinger() { + interval := time.Duration(br.Config.Homeserver.WSPingInterval) * time.Second + clock := time.NewTicker(interval) + defer func() { + br.ZLog.Info().Msg("Stopping websocket pinger") + clock.Stop() + }() + br.ZLog.Info().Dur("interval_duration", interval).Msg("Starting websocket pinger") + for { + select { + case <-clock.C: + br.PingServer() + case <-br.wsStopPinger: + return + } + if br.Stopping { + return + } + } +} diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 61318d94..6b9242c2 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -9,7 +9,6 @@ package bridgev2 import ( "context" "fmt" - "runtime/debug" "time" "github.com/rs/zerolog" @@ -38,10 +37,8 @@ func (br *Bridge) RunBackfillQueue() { return } ctx, cancel := context.WithCancel(log.WithContext(context.Background())) - br.stopBackfillQueue.Clear() - stopChan := br.stopBackfillQueue.GetChan() go func() { - <-stopChan + <-br.stopBackfillQueue cancel() }() batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second @@ -63,7 +60,7 @@ func (br *Bridge) RunBackfillQueue() { } } noTasksFoundCount = 0 - case <-stopChan: + case <-br.stopBackfillQueue: if !timer.Stop() { select { case <-timer.C: @@ -80,30 +77,17 @@ func (br *Bridge) RunBackfillQueue() { time.Sleep(BackfillQueueErrorBackoff) continue } else if backfillTask != nil { - br.DoBackfillTask(ctx, backfillTask) + br.doBackfillTask(ctx, backfillTask) noTasksFoundCount = 0 } } } -func (br *Bridge) DoBackfillTask(ctx context.Context, task *database.BackfillTask) { +func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTask) { log := zerolog.Ctx(ctx).With(). Object("portal_key", task.PortalKey). Str("login_id", string(task.UserLoginID)). Logger() - defer func() { - err := recover() - if err != nil { - logEvt := log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()) - if realErr, ok := err.(error); ok { - logEvt = logEvt.Err(realErr) - } else { - logEvt = logEvt.Any(zerolog.ErrorFieldName, err) - } - logEvt.Msg("Panic in backfill queue") - } - }() ctx = log.WithContext(ctx) err := br.DB.BackfillTask.MarkDispatched(ctx, task) if err != nil { @@ -169,30 +153,20 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac login, err := br.GetExistingUserLoginByID(ctx, task.UserLoginID) if err != nil { return false, fmt.Errorf("failed to get user login for backfill task: %w", err) - } else if login == nil || !login.Client.IsLoggedIn() { - if login == nil { - log.Warn().Msg("User login not found for backfill task") - } else { - log.Warn().Msg("User login not logged in for backfill task") - } + } else if login == nil { + log.Warn().Msg("User login not found for backfill task") logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { return false, fmt.Errorf("failed to get user portals for backfill task: %w", err) } else if len(logins) == 0 { log.Debug().Msg("No user logins found for backfill task") task.NextDispatchMinTS = database.BackfillNextDispatchNever - if login == nil { - task.UserLoginID = "" - } + task.UserLoginID = "" return false, nil } - if login == nil { - task.UserLoginID = "" - } - foundLogin := false + task.UserLoginID = "" for _, login = range logins { if login.Client.IsLoggedIn() { - foundLogin = true task.UserLoginID = login.ID log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("overridden_login_id", string(login.ID)) @@ -201,7 +175,7 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac break } } - if !foundLogin { + if task.UserLoginID == "" { log.Debug().Msg("No logged in user logins found for backfill task") task.NextDispatchMinTS = database.BackfillNextDispatchNever return false, nil diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 226adc90..aadefb0a 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -9,20 +9,15 @@ package bridgev2 import ( "context" "fmt" - "os" "sync" - "sync/atomic" - "time" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" - "go.mau.fi/util/exhttp" - "go.mau.fi/util/exsync" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/id" ) @@ -50,17 +45,8 @@ type Bridge struct { ghostsByID map[networkid.UserID]*Ghost cacheLock sync.Mutex - didSplitPortals bool - - Background bool - ExternallyManagedDB bool - stopping atomic.Bool - wakeupBackfillQueue chan struct{} - stopBackfillQueue *exsync.Event - - BackgroundCtx context.Context - cancelBackgroundCtx context.CancelFunc + stopBackfillQueue chan struct{} } func NewBridge( @@ -88,7 +74,7 @@ func NewBridge( ghostsByID: make(map[networkid.UserID]*Ghost), wakeupBackfillQueue: make(chan struct{}), - stopBackfillQueue: exsync.NewEvent(), + stopBackfillQueue: make(chan struct{}), } if br.Config == nil { br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} @@ -114,89 +100,28 @@ func (e DBUpgradeError) Unwrap() error { return e.Err } -func (br *Bridge) Start(ctx context.Context) error { - ctx = br.Log.WithContext(ctx) - err := br.StartConnectors(ctx) +func (br *Bridge) Start() error { + err := br.StartConnectors() if err != nil { return err } - err = br.StartLogins(ctx) + err = br.StartLogins() if err != nil { return err } - go br.PostStart(ctx) return nil } -func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error { - br.Background = true - br.stopping.Store(false) - err := br.StartConnectors(ctx) - if err != nil { - return err - } - - if loginID == "" { - br.Log.Info().Msg("No login ID provided to RunOnce, running all logins for 20 seconds") - err = br.StartLogins(ctx) - if err != nil { - return err - } - defer br.StopWithTimeout(5 * time.Second) - select { - case <-time.After(20 * time.Second): - case <-ctx.Done(): - } - return nil - } - - defer br.stop(true, 5*time.Second) - login, err := br.GetExistingUserLoginByID(ctx, loginID) - if err != nil { - return fmt.Errorf("failed to get user login: %w", err) - } else if login == nil { - return ErrNotLoggedIn - } - syncClient, ok := login.Client.(BackgroundSyncingNetworkAPI) - if !ok { - br.Log.Warn().Msg("Network connector doesn't implement background mode, using fallback mechanism for RunOnce") - login.Client.Connect(ctx) - defer login.DisconnectWithTimeout(5 * time.Second) - select { - case <-time.After(20 * time.Second): - case <-ctx.Done(): - } - br.stopping.Store(true) - return nil - } else { - br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") - return syncClient.ConnectBackground(login.Log.WithContext(ctx), params) - } -} - -func (br *Bridge) StartConnectors(ctx context.Context) error { +func (br *Bridge) StartConnectors() error { br.Log.Info().Msg("Starting bridge") - br.stopping.Store(false) - if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { - br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) - br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx) - } + ctx := br.Log.WithContext(context.Background()) - if !br.ExternallyManagedDB { - err := br.DB.Upgrade(ctx) - if err != nil { - return DBUpgradeError{Err: err, Section: "main"} - } - } - if !br.Background { - var postMigrate func() - br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx) - if postMigrate != nil { - defer postMigrate() - } + err := br.DB.Upgrade(ctx) + if err != nil { + return DBUpgradeError{Err: err, Section: "main"} } br.Log.Info().Msg("Starting Matrix connector") - err := br.Matrix.Start(ctx) + err = br.Matrix.Start(ctx) if err != nil { return fmt.Errorf("failed to start Matrix connector: %w", err) } @@ -205,144 +130,15 @@ func (br *Bridge) StartConnectors(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to start network connector: %w", err) } - if br.Network.GetCapabilities().DisappearingMessages && !br.Background { + if br.Network.GetCapabilities().DisappearingMessages { go br.DisappearLoop.Start() } return nil } -func (br *Bridge) PostStart(ctx context.Context) { - if br.Background { - return - } - rawBridgeInfoVer := br.DB.KV.Get(ctx, database.KeyBridgeInfoVersion) - bridgeInfoVer, capVer, err := parseBridgeInfoVersion(rawBridgeInfoVer) - if err != nil { - br.Log.Err(err).Str("db_bridge_info_version", rawBridgeInfoVer).Msg("Failed to parse bridge info version") - return - } - expectedBridgeInfoVer, expectedCapVer := br.Network.GetBridgeInfoVersion() - doResendBridgeInfo := bridgeInfoVer != expectedBridgeInfoVer || br.didSplitPortals || br.Config.ResendBridgeInfo - doResendCapabilities := capVer != expectedCapVer || br.didSplitPortals - if doResendBridgeInfo || doResendCapabilities { - br.ResendBridgeInfo(ctx, doResendBridgeInfo, doResendCapabilities) - } - br.DB.KV.Set(ctx, database.KeyBridgeInfoVersion, fmt.Sprintf("%d,%d", expectedBridgeInfoVer, expectedCapVer)) -} +func (br *Bridge) StartLogins() error { + ctx := br.Log.WithContext(context.Background()) -func parseBridgeInfoVersion(version string) (info, capabilities int, err error) { - _, err = fmt.Sscanf(version, "%d,%d", &info, &capabilities) - if version == "" { - err = nil - } - return -} - -func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps bool) { - log := zerolog.Ctx(ctx).With().Str("action", "resend bridge info").Logger() - portals, err := br.GetAllPortalsWithMXID(ctx) - if err != nil { - log.Err(err).Msg("Failed to get portals") - return - } - for _, portal := range portals { - if resendInfo { - portal.UpdateBridgeInfo(ctx) - } - if resendCaps { - logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) - if err != nil { - log.Err(err). - Stringer("room_id", portal.MXID). - Object("portal_key", portal.PortalKey). - Msg("Failed to get user logins in portal") - } else { - found := false - for _, login := range logins { - if portal.CapState.ID == "" || login.ID == portal.CapState.Source { - portal.UpdateCapabilities(ctx, login, true) - found = true - } - } - if !found && len(logins) > 0 { - portal.CapState.Source = "" - portal.UpdateCapabilities(ctx, logins[0], true) - } else if !found { - log.Warn(). - Stringer("room_id", portal.MXID). - Object("portal_key", portal.PortalKey). - Msg("No user login found to update capabilities") - } - } - } - } - log.Info(). - Bool("capabilities", resendCaps). - Bool("info", resendInfo). - Msg("Resent bridge info to all portals") -} - -func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) { - log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger() - ctx = log.WithContext(ctx) - if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" { - return false, nil - } - affected, err := br.DB.Portal.MigrateToSplitPortals(ctx) - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals") - os.Exit(31) - return false, nil - } - log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") - affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx) - if err != nil { - log.Err(err).Msg("Failed to fix parent portals after split portal migration") - os.Exit(31) - return false, nil - } - log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration") - withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") - os.Exit(31) - return false, nil - } - var roomsToDelete []id.RoomID - log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver") - for _, portal := range withoutReceiver { - if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil { - log.Err(err). - Str("portal_id", string(portal.ID)). - Stringer("mxid", portal.MXID). - Msg("Failed to delete portal database row that failed to migrate") - } else if portal.MXID != "" { - log.Debug(). - Str("portal_id", string(portal.ID)). - Stringer("mxid", portal.MXID). - Msg("Marked portal room for deletion from homeserver") - roomsToDelete = append(roomsToDelete, portal.MXID) - } else { - log.Debug(). - Str("portal_id", string(portal.ID)). - Msg("Deleted portal row with no Matrix room") - } - } - br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true") - log.Info().Msg("Finished split portal migration successfully") - return affected > 0, func() { - for _, roomID := range roomsToDelete { - if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil { - log.Err(err). - Stringer("mxid", roomID). - Msg("Failed to delete portal room that failed to migrate") - } - } - log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate") - } -} - -func (br *Bridge) StartLogins(ctx context.Context) error { userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx) if err != nil { return fmt.Errorf("failed to get users with logins: %w", err) @@ -355,10 +151,13 @@ func (br *Bridge) StartLogins(ctx context.Context) error { if err != nil { br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user") } else { - for _, login := range user.GetUserLogins() { + for _, login := range user.GetCachedUserLogins() { startedAny = true br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") - login.Client.Connect(login.Log.WithContext(ctx)) + err = login.Client.Connect(login.Log.WithContext(ctx)) + if err != nil { + br.Log.Err(err).Msg("Failed to connect existing client") + } } } } @@ -366,93 +165,27 @@ func (br *Bridge) StartLogins(ctx context.Context) error { br.Log.Info().Msg("No user logins found") br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) } - if !br.Background { - go br.RunBackfillQueue() - } + go br.RunBackfillQueue() br.Log.Info().Msg("Bridge started") return nil } -func (br *Bridge) ResetNetworkConnections() { - nrn, ok := br.Network.(NetworkResettingNetwork) - if ok { - br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections") - nrn.ResetNetworkConnections() - return - } - - br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually") - for _, login := range br.GetAllCachedUserLogins() { - login.Log.Debug().Msg("Disconnecting and recreating client for network reset") - ctx := login.Log.WithContext(br.BackgroundCtx) - login.Client.Disconnect() - err := login.recreateClient(ctx) - if err != nil { - login.Log.Err(err).Msg("Failed to recreate client during network reset") - login.BridgeState.Send(status.BridgeState{ - StateEvent: status.StateUnknownError, - Error: "bridgev2-network-reset-fail", - Info: map[string]any{"go_error": err.Error()}, - }) - } else { - login.Client.Connect(ctx) - } - } - br.Log.Info().Msg("Finished resetting all user logins") -} - -func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings { - mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings) - if ok { - return mchs.GetHTTPClientSettings() - } - return exhttp.SensibleClientSettings -} - -func (br *Bridge) IsStopping() bool { - return br.stopping.Load() -} - func (br *Bridge) Stop() { - br.stop(false, 0) -} - -func (br *Bridge) StopWithTimeout(timeout time.Duration) { - br.stop(false, timeout) -} - -func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) { br.Log.Info().Msg("Shutting down bridge") - br.stopping.Store(true) - br.DisappearLoop.Stop() - br.stopBackfillQueue.Set() - br.Matrix.PreStop() - if !isRunOnce { - br.cacheLock.Lock() - var wg sync.WaitGroup - wg.Add(len(br.userLoginsByID)) - for _, login := range br.userLoginsByID { - go func() { - login.DisconnectWithTimeout(timeout) - wg.Done() - }() - } - br.cacheLock.Unlock() - wg.Wait() - } + close(br.stopBackfillQueue) br.Matrix.Stop() - if br.cancelBackgroundCtx != nil { - br.cancelBackgroundCtx() + br.cacheLock.Lock() + var wg sync.WaitGroup + wg.Add(len(br.userLoginsByID)) + for _, login := range br.userLoginsByID { + go login.Disconnect(wg.Done) } - if stopNet, ok := br.Network.(StoppableNetwork); ok { - stopNet.Stop() - } - if !br.ExternallyManagedDB { - err := br.DB.Close() - if err != nil { - br.Log.Warn().Err(err).Msg("Failed to close database") - } + wg.Wait() + br.cacheLock.Unlock() + err := br.DB.Close() + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to close database") } br.Log.Info().Msg("Shutdown complete") } diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go index f709c8e0..9ff333e9 100644 --- a/bridgev2/bridgeconfig/appservice.go +++ b/bridgev2/bridgeconfig/appservice.go @@ -8,9 +8,9 @@ package bridgeconfig import ( "fmt" + "html/template" "regexp" "strings" - "text/template" "go.mau.fi/util/exerrors" "go.mau.fi/util/random" @@ -79,18 +79,12 @@ func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registr registration.SoruEphemeralEvents = asc.EphemeralEvents } -func (ec *EncryptionConfig) applyUnstableFlags(registration *appservice.Registration) { - registration.MSC4190 = ec.MSC4190 - registration.MSC3202 = ec.Appservice -} - // GenerateRegistration generates a registration file for the homeserver. func (config *Config) GenerateRegistration() *appservice.Registration { registration := appservice.CreateRegistration() config.AppService.HSToken = registration.ServerToken config.AppService.ASToken = registration.AppToken config.AppService.copyToRegistration(registration) - config.Encryption.applyUnstableFlags(registration) registration.SenderLocalpart = random.String(32) botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", @@ -109,7 +103,6 @@ func (config *Config) MakeAppService() *appservice.AppService { as.Host.Hostname = config.AppService.Hostname as.Host.Port = config.AppService.Port as.Registration = config.AppService.GetRegistration() - config.Encryption.applyUnstableFlags(as.Registration) return as } diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index eedae1e8..44d2d588 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -14,11 +14,6 @@ type BackfillConfig struct { Threads BackfillThreadsConfig `yaml:"threads"` Queue BackfillQueueConfig `yaml:"queue"` - - // Flag to indicate that the creator will not run the backfill queue but will still paginate - // backfill by calling DoBackfillTask directly. Note that this is not used anywhere within - // mautrix-go and exists so bridges can use it to decide when to drop backfill data. - WillPaginateManually bool `yaml:"will_paginate_manually"` } type BackfillThreadsConfig struct { @@ -34,12 +29,10 @@ type BackfillQueueConfig struct { MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } -func (bqc *BackfillQueueConfig) GetOverride(names ...string) int { - for _, name := range names { - override, ok := bqc.MaxBatchesOverride[name] - if ok { - return override - } +func (bqc *BackfillQueueConfig) GetOverride(name string) int { + override, ok := bqc.MaxBatchesOverride[name] + if !ok { + return bqc.MaxBatches } - return bqc.MaxBatches + return override } diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index bd6b9c06..ab97c891 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -7,13 +7,10 @@ package bridgeconfig import ( - "time" - "go.mau.fi/util/dbutil" "go.mau.fi/zeroconfig" "gopkg.in/yaml.v3" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/mediaproxy" ) @@ -24,7 +21,6 @@ type Config struct { Homeserver HomeserverConfig `yaml:"homeserver"` AppService AppserviceConfig `yaml:"appservice"` Matrix MatrixConfig `yaml:"matrix"` - Analytics AnalyticsConfig `yaml:"analytics"` Provisioning ProvisioningConfig `yaml:"provisioning"` PublicMedia PublicMediaConfig `yaml:"public_media"` DirectMedia DirectMediaConfig `yaml:"direct_media"` @@ -33,8 +29,6 @@ type Config struct { Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` - EnvConfigPrefix string `yaml:"env_config_prefix"` - ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` } @@ -62,52 +56,30 @@ type CleanupOnLogouts struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` - PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` - AsyncEvents bool `yaml:"async_events"` - SplitPortals bool `yaml:"split_portals"` - ResendBridgeInfo bool `yaml:"resend_bridge_info"` - NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` - BridgeStatusNotices string `yaml:"bridge_status_notices"` - UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` - UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"` - BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` - BridgeNotices bool `yaml:"bridge_notices"` - TagOnlyOnCreate bool `yaml:"tag_only_on_create"` - OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` - MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` - DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` - CrossRoomReplies bool `yaml:"cross_room_replies"` - OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` - RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` - KickMatrixUsers bool `yaml:"kick_matrix_users"` - CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` - Relay RelayConfig `yaml:"relay"` - Permissions PermissionConfig `yaml:"permissions"` - Backfill BackfillConfig `yaml:"backfill"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` } type MatrixConfig struct { - MessageStatusEvents bool `yaml:"message_status_events"` - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageErrorNotices bool `yaml:"message_error_notices"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - FederateRooms bool `yaml:"federate_rooms"` - UploadFileThreshold int64 `yaml:"upload_file_threshold"` - GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"` -} - -type AnalyticsConfig struct { - Token string `yaml:"token"` - URL string `yaml:"url"` - UserID string `yaml:"user_id"` + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` } type ProvisioningConfig struct { - SharedSecret string `yaml:"shared_secret"` - DebugEndpoints bool `yaml:"debug_endpoints"` - EnableSessionTransfers bool `yaml:"enable_session_transfers"` + Prefix string `yaml:"prefix"` + SharedSecret string `yaml:"shared_secret"` + DebugEndpoints bool `yaml:"debug_endpoints"` } type DirectMediaConfig struct { @@ -117,12 +89,10 @@ type DirectMediaConfig struct { } type PublicMediaConfig struct { - Enabled bool `yaml:"enabled"` - SigningKey string `yaml:"signing_key"` - Expiry int `yaml:"expiry"` - HashLength int `yaml:"hash_length"` - PathPrefix string `yaml:"path_prefix"` - UseDatabase bool `yaml:"use_database"` + Enabled bool `yaml:"enabled"` + SigningKey string `yaml:"signing_key"` + HashLength int `yaml:"hash_length"` + Expiry int `yaml:"expiry"` } type DoublePuppetConfig struct { diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go index 934613ca..93a427d3 100644 --- a/bridgev2/bridgeconfig/encryption.go +++ b/bridgev2/bridgeconfig/encryption.go @@ -15,9 +15,6 @@ type EncryptionConfig struct { Default bool `yaml:"default"` Require bool `yaml:"require"` Appservice bool `yaml:"appservice"` - MSC4190 bool `yaml:"msc4190"` - MSC4392 bool `yaml:"msc4392"` - SelfSign bool `yaml:"self_sign"` PlaintextMentions bool `yaml:"plaintext_mentions"` diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go deleted file mode 100644 index 954a37c3..00000000 --- a/bridgev2/bridgeconfig/legacymigrate.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgeconfig - -import ( - "fmt" - "net/url" - "os" - "strings" - - up "go.mau.fi/util/configupgrade" -) - -var HackyMigrateLegacyNetworkConfig func(up.Helper) - -func CopyToOtherLocation(helper up.Helper, fieldType up.YAMLType, source, dest []string) { - val, ok := helper.Get(fieldType, source...) - if ok { - helper.Set(fieldType, val, dest...) - } -} - -func CopyMapToOtherLocation(helper up.Helper, source, dest []string) { - val := helper.GetNode(source...) - if val != nil && val.Map != nil { - helper.SetMap(val.Map, dest...) - } -} - -func doMigrateLegacy(helper up.Helper, python bool) { - if HackyMigrateLegacyNetworkConfig == nil { - _, _ = fmt.Fprintln(os.Stderr, "Legacy bridge config detected, but hacky network config migrator is not set") - os.Exit(1) - } - _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy bridge config") - - helper.Copy(up.Str, "homeserver", "address") - helper.Copy(up.Str, "homeserver", "domain") - helper.Copy(up.Str, "homeserver", "software") - helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") - helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") - helper.Copy(up.Bool, "homeserver", "async_media") - helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") - helper.Copy(up.Bool, "homeserver", "websocket") - helper.Copy(up.Int, "homeserver", "ping_interval_seconds") - - helper.Copy(up.Str|up.Null, "appservice", "address") - helper.Copy(up.Str|up.Null, "appservice", "hostname") - helper.Copy(up.Int|up.Null, "appservice", "port") - helper.Copy(up.Str, "appservice", "id") - if python { - CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_username"}, []string{"appservice", "bot", "username"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_displayname"}, []string{"appservice", "bot", "displayname"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_avatar"}, []string{"appservice", "bot", "avatar"}) - } else { - helper.Copy(up.Str, "appservice", "bot", "username") - helper.Copy(up.Str, "appservice", "bot", "displayname") - helper.Copy(up.Str, "appservice", "bot", "avatar") - } - helper.Copy(up.Bool, "appservice", "ephemeral_events") - helper.Copy(up.Bool, "appservice", "async_transactions") - helper.Copy(up.Str, "appservice", "as_token") - helper.Copy(up.Str, "appservice", "hs_token") - - helper.Copy(up.Str, "bridge", "command_prefix") - helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") - if oldPM, ok := helper.Get(up.Str, "bridge", "private_chat_portal_meta"); ok && (oldPM == "default" || oldPM == "always") { - helper.Set(up.Bool, "true", "bridge", "private_chat_portal_meta") - } else { - helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") - } - helper.Copy(up.Bool, "bridge", "relay", "enabled") - helper.Copy(up.Bool, "bridge", "relay", "admin_only") - helper.Copy(up.Map, "bridge", "permissions") - - if python { - legacyDB, ok := helper.Get(up.Str, "appservice", "database") - if ok { - if strings.HasPrefix(legacyDB, "postgres") { - parsedDB, err := url.Parse(legacyDB) - if err != nil { - panic(err) - } - q := parsedDB.Query() - if parsedDB.Host == "" && !q.Has("host") { - q.Set("host", "/var/run/postgresql") - } else if !q.Has("sslmode") { - q.Set("sslmode", "disable") - } - parsedDB.RawQuery = q.Encode() - helper.Set(up.Str, parsedDB.String(), "database", "uri") - helper.Set(up.Str, "postgres", "database", "type") - } else { - dbPath := strings.TrimPrefix(strings.TrimPrefix(legacyDB, "sqlite:"), "///") - helper.Set(up.Str, fmt.Sprintf("file:%s?_txlock=immediate", dbPath), "database", "uri") - helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") - } - } - if legacyDBMinSize, ok := helper.Get(up.Int, "appservice", "database_opts", "min_size"); ok { - helper.Set(up.Int, legacyDBMinSize, "database", "max_idle_conns") - } - if legacyDBMaxSize, ok := helper.Get(up.Int, "appservice", "database_opts", "max_size"); ok { - helper.Set(up.Int, legacyDBMaxSize, "database", "max_open_conns") - } - } else { - if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { - helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") - } else { - CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) - } - CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_open_conns"}, []string{"database", "max_open_conns"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_idle_conns"}, []string{"database", "max_idle_conns"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_idle_time"}, []string{"database", "max_conn_idle_time"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_lifetime"}, []string{"database", "max_conn_lifetime"}) - } - - if python { - if usernameTemplate, ok := helper.Get(up.Str, "bridge", "username_template"); ok && strings.Contains(usernameTemplate, "{userid}") { - helper.Set(up.Str, strings.ReplaceAll(usernameTemplate, "{userid}", "{{.}}"), "appservice", "username_template") - } - } else { - CopyToOtherLocation(helper, up.Str, []string{"bridge", "username_template"}, []string{"appservice", "username_template"}) - } - - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_status_events"}, []string{"matrix", "message_status_events"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "delivery_receipts"}, []string{"matrix", "delivery_receipts"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_error_notices"}, []string{"matrix", "message_error_notices"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) - - CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) - - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "double_puppet_allow_discovery"}, []string{"double_puppet", "allow_discovery"}) - CopyMapToOtherLocation(helper, []string{"bridge", "double_puppet_server_map"}, []string{"double_puppet", "servers"}) - CopyMapToOtherLocation(helper, []string{"bridge", "login_shared_secret_map"}, []string{"double_puppet", "secrets"}) - - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow"}, []string{"encryption", "allow"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "default"}, []string{"encryption", "default"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "require"}, []string{"encryption", "require"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "appservice"}, []string{"encryption", "appservice"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow_key_sharing"}, []string{"encryption", "allow_key_sharing"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outbound_on_ack"}, []string{"encryption", "delete_keys", "delete_outbound_on_ack"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "dont_store_outbound"}, []string{"encryption", "delete_keys", "dont_store_outbound"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "ratchet_on_decrypt"}, []string{"encryption", "delete_keys", "ratchet_on_decrypt"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_fully_used_on_decrypt"}, []string{"encryption", "delete_keys", "delete_fully_used_on_decrypt"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_prev_on_new_session"}, []string{"encryption", "delete_keys", "delete_prev_on_new_session"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_on_device_delete"}, []string{"encryption", "delete_keys", "delete_on_device_delete"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "periodically_delete_expired"}, []string{"encryption", "delete_keys", "periodically_delete_expired"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outdated_inbound"}, []string{"encryption", "delete_keys", "delete_outdated_inbound"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "receive"}, []string{"encryption", "verification_levels", "receive"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "send"}, []string{"encryption", "verification_levels", "send"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "share"}, []string{"encryption", "verification_levels", "share"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "enable_custom"}, []string{"encryption", "rotation", "enable_custom"}) - CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "milliseconds"}, []string{"encryption", "rotation", "milliseconds"}) - CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "messages"}, []string{"encryption", "rotation", "messages"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "disable_device_change_key_rotation"}, []string{"encryption", "rotation", "disable_device_change_key_rotation"}) - - if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating maulogger configs is not supported") - } else if (helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil)) || python { - _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log configs is not supported") - } else { - helper.Copy(up.Map, "logging") - } - - HackyMigrateLegacyNetworkConfig(helper) -} diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 9efe068e..e76046f5 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -8,8 +8,6 @@ package bridgeconfig import ( "fmt" - "os" - "strconv" "strings" "gopkg.in/yaml.v3" @@ -24,7 +22,6 @@ type Permissions struct { DoublePuppet bool `yaml:"double_puppet"` Admin bool `yaml:"admin"` ManageRelay bool `yaml:"manage_relay"` - MaxLogins int `yaml:"max_logins"` } type PermissionConfig map[string]*Permissions @@ -41,7 +38,10 @@ func (pc PermissionConfig) IsConfigured() bool { _, hasExampleDomain := pc["example.com"] _, hasExampleUser := pc["@admin:example.com"] exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) - return len(pc) > exampleLen + if len(pc) <= exampleLen { + return false + } + return true } func (pc PermissionConfig) Get(userID id.UserID) Permissions { @@ -94,23 +94,6 @@ func (p *Permissions) UnmarshalYAML(perm *yaml.Node) error { case "!!map": err := perm.Decode((*umPerm)(p)) return err - case "!!int": - val, err := strconv.Atoi(perm.Value) - if err != nil { - return fmt.Errorf("invalid permissions level %s", perm.Value) - } - _, _ = fmt.Fprintln(os.Stderr, "Warning: config contains deprecated integer permission values") - // Integer values are deprecated, so they're hardcoded - if val < 5 { - *p = PermissionLevelBlock - } else if val < 10 { - *p = PermissionLevelRelay - } else if val < 100 { - *p = PermissionLevelUser - } else { - *p = PermissionLevelAdmin - } - return nil default: return fmt.Errorf("invalid permissions type %s", perm.Tag) } diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 92515ea0..4eff205d 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -8,6 +8,7 @@ package bridgeconfig import ( "fmt" + "os" up "go.mau.fi/util/configupgrade" "go.mau.fi/util/random" @@ -17,32 +18,16 @@ import ( func doUpgrade(helper up.Helper) { if _, isLegacyConfig := helper.Get(up.Str, "appservice", "database", "uri"); isLegacyConfig { - doMigrateLegacy(helper, false) - return - } else if _, isLegacyPython := helper.Get(up.Str, "appservice", "database"); isLegacyPython { - doMigrateLegacy(helper, true) + doMigrateLegacy(helper) return } helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") - helper.Copy(up.Bool, "bridge", "async_events") - helper.Copy(up.Bool, "bridge", "split_portals") - helper.Copy(up.Bool, "bridge", "resend_bridge_info") - helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") - helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") - helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect") - helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") - helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") - helper.Copy(up.List, "bridge", "only_bridge_tags") helper.Copy(up.Bool, "bridge", "mute_only_on_create") - helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") - helper.Copy(up.Bool, "bridge", "cross_room_replies") - helper.Copy(up.Bool, "bridge", "revert_failed_state_changes") - helper.Copy(up.Bool, "bridge", "kick_matrix_users") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") @@ -60,7 +45,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Map, "bridge", "permissions") if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { - fmt.Println("Warning: invalid database type sqlite3 in config. Autocorrecting to sqlite3-fk-wal") helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") } else { helper.Copy(up.Str, "database", "type") @@ -100,13 +84,8 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "message_error_notices") helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") helper.Copy(up.Bool, "matrix", "federate_rooms") - helper.Copy(up.Int, "matrix", "upload_file_threshold") - helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info") - - helper.Copy(up.Str|up.Null, "analytics", "token") - helper.Copy(up.Str|up.Null, "analytics", "url") - helper.Copy(up.Str|up.Null, "analytics", "user_id") + helper.Copy(up.Str, "provisioning", "prefix") if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { sharedSecret := random.String(64) helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret") @@ -114,7 +93,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "provisioning", "shared_secret") } helper.Copy(up.Bool, "provisioning", "debug_endpoints") - helper.Copy(up.Bool, "provisioning", "enable_session_transfers") helper.Copy(up.Bool, "direct_media", "enabled") helper.Copy(up.Str|up.Null, "direct_media", "media_id_prefix") @@ -136,8 +114,6 @@ func doUpgrade(helper up.Helper) { } helper.Copy(up.Int, "public_media", "expiry") helper.Copy(up.Int, "public_media", "hash_length") - helper.Copy(up.Str|up.Null, "public_media", "path_prefix") - helper.Copy(up.Bool, "public_media", "use_database") helper.Copy(up.Bool, "backfill", "enabled") helper.Copy(up.Int, "backfill", "max_initial_messages") @@ -158,13 +134,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "encryption", "default") helper.Copy(up.Bool, "encryption", "require") helper.Copy(up.Bool, "encryption", "appservice") - if val, ok := helper.Get(up.Bool, "appservice", "msc4190"); ok { - helper.Set(up.Bool, val, "encryption", "msc4190") - } else { - helper.Copy(up.Bool, "encryption", "msc4190") - } - helper.Copy(up.Bool, "encryption", "msc4392") - helper.Copy(up.Bool, "encryption", "self_sign") helper.Copy(up.Bool, "encryption", "allow_key_sharing") if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { helper.Set(up.Str, random.String(64), "encryption", "pickle_key") @@ -187,15 +156,124 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "encryption", "rotation", "messages") helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") - helper.Copy(up.Str|up.Null, "env_config_prefix") - helper.Copy(up.Map, "logging") } +func CopyToOtherLocation(helper up.Helper, fieldType up.YAMLType, source, dest []string) { + val, ok := helper.Get(fieldType, source...) + if ok { + helper.Set(fieldType, val, dest...) + } +} + +func CopyMapToOtherLocation(helper up.Helper, source, dest []string) { + val := helper.GetNode(source...) + if val != nil && val.Map != nil { + helper.SetMap(val.Map, dest...) + } +} + +var HackyMigrateLegacyNetworkConfig func(up.Helper) + +func doMigrateLegacy(helper up.Helper) { + if HackyMigrateLegacyNetworkConfig == nil { + _, _ = fmt.Fprintln(os.Stderr, "Legacy bridge config detected, but hacky network config migrator is not set") + os.Exit(1) + } + _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy bridge config") + + helper.Copy(up.Str, "homeserver", "address") + helper.Copy(up.Str, "homeserver", "domain") + helper.Copy(up.Str, "homeserver", "software") + helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") + helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") + helper.Copy(up.Bool, "homeserver", "async_media") + helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") + helper.Copy(up.Bool, "homeserver", "websocket") + helper.Copy(up.Int, "homeserver", "ping_interval_seconds") + + helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "hostname") + helper.Copy(up.Int|up.Null, "appservice", "port") + helper.Copy(up.Str, "appservice", "id") + helper.Copy(up.Str, "appservice", "bot", "username") + helper.Copy(up.Str, "appservice", "bot", "displayname") + helper.Copy(up.Str, "appservice", "bot", "avatar") + helper.Copy(up.Bool, "appservice", "ephemeral_events") + helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Str, "appservice", "as_token") + helper.Copy(up.Str, "appservice", "hs_token") + + helper.Copy(up.Str, "bridge", "command_prefix") + helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + if oldPM, ok := helper.Get(up.Str, "bridge", "private_chat_portal_meta"); ok && (oldPM == "default" || oldPM == "always") { + helper.Set(up.Bool, "true", "bridge", "private_chat_portal_meta") + } else { + helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") + } + helper.Copy(up.Bool, "bridge", "relay", "enabled") + helper.Copy(up.Bool, "bridge", "relay", "admin_only") + helper.Copy(up.Map, "bridge", "permissions") + + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_open_conns"}, []string{"database", "max_open_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_idle_conns"}, []string{"database", "max_idle_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_idle_time"}, []string{"database", "max_conn_idle_time"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_lifetime"}, []string{"database", "max_conn_lifetime"}) + + CopyToOtherLocation(helper, up.Str, []string{"bridge", "username_template"}, []string{"appservice", "username_template"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_status_events"}, []string{"matrix", "message_status_events"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "delivery_receipts"}, []string{"matrix", "delivery_receipts"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_error_notices"}, []string{"matrix", "message_error_notices"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) + + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "double_puppet_allow_discovery"}, []string{"double_puppet", "allow_discovery"}) + CopyMapToOtherLocation(helper, []string{"bridge", "double_puppet_server_map"}, []string{"double_puppet", "servers"}) + CopyMapToOtherLocation(helper, []string{"bridge", "login_shared_secret_map"}, []string{"double_puppet", "secrets"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow"}, []string{"encryption", "allow"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "default"}, []string{"encryption", "default"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "require"}, []string{"encryption", "require"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "appservice"}, []string{"encryption", "appservice"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow_key_sharing"}, []string{"encryption", "allow_key_sharing"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outbound_on_ack"}, []string{"encryption", "delete_keys", "delete_outbound_on_ack"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "dont_store_outbound"}, []string{"encryption", "delete_keys", "dont_store_outbound"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "ratchet_on_decrypt"}, []string{"encryption", "delete_keys", "ratchet_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_fully_used_on_decrypt"}, []string{"encryption", "delete_keys", "delete_fully_used_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_prev_on_new_session"}, []string{"encryption", "delete_keys", "delete_prev_on_new_session"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_on_device_delete"}, []string{"encryption", "delete_keys", "delete_on_device_delete"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "periodically_delete_expired"}, []string{"encryption", "delete_keys", "periodically_delete_expired"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outdated_inbound"}, []string{"encryption", "delete_keys", "delete_outdated_inbound"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "receive"}, []string{"encryption", "verification_levels", "receive"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "send"}, []string{"encryption", "verification_levels", "send"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "share"}, []string{"encryption", "verification_levels", "share"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "enable_custom"}, []string{"encryption", "rotation", "enable_custom"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "milliseconds"}, []string{"encryption", "rotation", "milliseconds"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "messages"}, []string{"encryption", "rotation", "messages"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "disable_device_change_key_rotation"}, []string{"encryption", "rotation", "disable_device_change_key_rotation"}) + + if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating maulogger configs is not supported") + } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log configs is not supported") + } else { + helper.Copy(up.Map, "logging") + } + + HackyMigrateLegacyNetworkConfig(helper) +} + var SpacedBlocks = [][]string{ {"bridge"}, - {"bridge", "bridge_matrix_leave"}, - {"bridge", "cleanup_on_logout"}, {"bridge", "relay"}, {"bridge", "permissions"}, {"database"}, @@ -209,14 +287,12 @@ var SpacedBlocks = [][]string{ {"appservice", "as_token"}, {"appservice", "username_template"}, {"matrix"}, - {"analytics"}, {"provisioning"}, {"public_media"}, {"direct_media"}, {"backfill"}, {"double_puppet"}, {"encryption"}, - {"env_config_prefix"}, {"logging"}, } diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 96d9fd5c..1cd6b0c5 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -8,37 +8,20 @@ package bridgev2 import ( "context" - "fmt" - "math/rand/v2" "runtime/debug" - "sync/atomic" "time" "github.com/rs/zerolog" - "go.mau.fi/util/exfmt" - "maunium.net/go/mautrix/bridgev2/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/bridge/status" ) -var CatchBridgeStateQueuePanics = true - type BridgeStateQueue struct { prevUnsent *status.BridgeState prevSent *status.BridgeState - errorSent bool ch chan status.BridgeState bridge *Bridge - login *UserLogin - - firstTransientDisconnect time.Time - cancelScheduledNotice atomic.Pointer[context.CancelFunc] - - stopChan chan struct{} - stopReconnect atomic.Pointer[context.CancelFunc] - - unknownErrorReconnects int + user status.StandaloneCustomBridgeStateFiller } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -58,221 +41,51 @@ func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { } } -func (br *Bridge) NewBridgeStateQueue(login *UserLogin) *BridgeStateQueue { +func (br *Bridge) NewBridgeStateQueue(user status.StandaloneCustomBridgeStateFiller) *BridgeStateQueue { bsq := &BridgeStateQueue{ - ch: make(chan status.BridgeState, 10), - stopChan: make(chan struct{}), - bridge: br, - login: login, + ch: make(chan status.BridgeState, 10), + bridge: br, + user: user, } go bsq.loop() return bsq } func (bsq *BridgeStateQueue) Destroy() { - close(bsq.stopChan) close(bsq.ch) - bsq.StopUnknownErrorReconnect() -} - -func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { - if bsq == nil { - return - } - if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil { - (*cancelFn)() - } - if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil { - (*cancelFn)() - } } func (bsq *BridgeStateQueue) loop() { - if CatchBridgeStateQueuePanics { - defer func() { - err := recover() - if err != nil { - bsq.login.Log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()). - Any(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() - } + defer func() { + err := recover() + if err != nil { + bsq.bridge.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() for state := range bsq.ch { bsq.immediateSendBridgeState(state) } } -func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) { - log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() - ctx := log.WithContext(bsq.bridge.BackgroundCtx) - if !bsq.waitForTransientDisconnectReconnect(ctx) { - return - } - prevUnsent := bsq.GetPrevUnsent() - prev := bsq.GetPrev() - if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent || - prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect { - log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice") - return - } - log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice") - bsq.sendNotice(ctx, triggeredBy, true) -} - -func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) { - noticeConfig := bsq.bridge.Config.BridgeStatusNotices - isError := state.StateEvent == status.StateBadCredentials || - state.StateEvent == status.StateUnknownError || - state.UserAction == status.UserActionOpenNative || - (isDelayed && state.StateEvent == status.StateTransientDisconnect) - sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && - (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) - if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError { - bsq.firstTransientDisconnect = time.Time{} - } - if !sendNotice { - if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { - if bsq.firstTransientDisconnect.IsZero() { - bsq.firstTransientDisconnect = time.Now() - } - go bsq.scheduleNotice(state) - } - return - } - managementRoom, err := bsq.login.User.GetManagementRoom(ctx) - if err != nil { - bsq.login.Log.Err(err).Msg("Failed to get management room") - return - } - name := bsq.login.RemoteName - if name == "" { - name = fmt.Sprintf("`%s`", bsq.login.ID) - } - message := fmt.Sprintf("State update for %s: `%s`", name, state.StateEvent) - if state.Error != "" { - message += fmt.Sprintf(" (`%s`)", state.Error) - } - if isDelayed { - message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay)) - } - if state.Message != "" { - message += fmt.Sprintf(": %s", state.Message) - } - content := format.RenderMarkdown(message, true, false) - if !isError { - content.MsgType = event.MsgNotice - } - _, err = bsq.bridge.Bot.SendMessage(ctx, managementRoom, event.EventMessage, &event.Content{ - Parsed: content, - Raw: map[string]any{ - "fi.mau.bridge_state": state, - }, - }, nil) - if err != nil { - bsq.login.Log.Err(err).Msg("Failed to send bridge state notice") - } else { - bsq.errorSent = isError - } -} - -func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeState) { - log := bsq.login.Log.With().Str("action", "unknown error reconnect").Logger() - ctx := log.WithContext(bsq.bridge.BackgroundCtx) - if !bsq.waitForUnknownErrorReconnect(ctx) { - return - } - prevUnsent := bsq.GetPrevUnsent() - prev := bsq.GetPrev() - if triggeredBy.Timestamp != prev.Timestamp { - log.Debug().Msg("Not reconnecting as a new bridge state was sent after the unknown error") - return - } else if len(bsq.ch) > 0 { - log.Warn().Msg("Not reconnecting as there are unsent bridge states") - return - } else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError { - log.Debug().Msg("Not reconnecting as the previous state was not an unknown error") - return - } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects { - log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached") - return - } - bsq.unknownErrorReconnects++ - log.Info(). - Int("reconnect_num", bsq.unknownErrorReconnects). - Msg("Disconnecting and reconnecting login due to unknown error") - bsq.login.Disconnect() - log.Debug().Msg("Disconnection finished, recreating client and reconnecting") - err := bsq.login.recreateClient(ctx) - if err != nil { - log.Err(err).Msg("Failed to recreate client after unknown error") - return - } - bsq.login.Client.Connect(ctx) - log.Debug().Msg("Reconnection finished") -} - -func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) bool { - reconnectIn := bsq.bridge.Config.UnknownErrorAutoReconnect - // Don't allow too low values - if reconnectIn < 1*time.Minute { - return false - } - reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2)) - return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect) -} - -const TransientDisconnectNoticeDelay = 3 * time.Minute - -func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool { - timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay)) - zerolog.Ctx(ctx).Trace(). - Stringer("duration", timeUntilSchedule). - Msg("Waiting before sending notice about transient disconnect") - return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice) -} - -func (bsq *BridgeStateQueue) waitForReconnect( - ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc], -) bool { - cancelCtx, cancel := context.WithCancel(ctx) - defer cancel() - if oldCancel := ptr.Swap(&cancel); oldCancel != nil { - (*oldCancel)() - } - select { - case <-time.After(reconnectIn): - return ptr.CompareAndSwap(&cancel, nil) - case <-cancelCtx.Done(): - return false - case <-bsq.stopChan: - return false - } -} - func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { - if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { - bsq.login.Log.Debug(). - Str("state_event", string(state.StateEvent)). - Msg("Not sending bridge state as it's a duplicate") - return - } - if state.StateEvent == status.StateUnknownError { - go bsq.unknownErrorReconnect(state) - } - - ctx := bsq.login.Log.WithContext(context.Background()) - bsq.sendNotice(ctx, state, false) - retryIn := 2 for { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { + bsq.bridge.Log.Debug(). + Str("state_event", string(state.StateEvent)). + Msg("Not sending bridge state as it's a duplicate") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) err := bsq.bridge.Matrix.SendBridgeStatus(ctx, &state) cancel() if err != nil { - bsq.login.Log.Warn().Err(err). + bsq.bridge.Log.Warn().Err(err). Int("retry_in_seconds", retryIn). Msg("Failed to update bridge state") time.Sleep(time.Duration(retryIn) * time.Second) @@ -282,7 +95,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) } } else { bsq.prevSent = &state - bsq.login.Log.Debug(). + bsq.bridge.Log.Debug(). Any("bridge_state", state). Msg("Sent new bridge state") return @@ -295,11 +108,11 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { return } - state = state.Fill(bsq.login) + state = state.Fill(bsq.user) bsq.prevUnsent = &state if len(bsq.ch) >= 8 { - bsq.login.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") + bsq.bridge.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") select { case <-bsq.ch: default: @@ -308,7 +121,7 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { select { case bsq.ch <- state: default: - bsq.login.Log.Error().Msg("Bridge state queue is full, dropped new state") + bsq.bridge.Log.Error().Msg("Bridge state queue is full, dropped new state") } } diff --git a/bridgev2/commands/cleanup.go b/bridgev2/commands/cleanup.go index dc21a16e..f8ad1d23 100644 --- a/bridgev2/commands/cleanup.go +++ b/bridgev2/commands/cleanup.go @@ -55,43 +55,3 @@ var CommandDeleteAllPortals = &FullHandler{ }, RequiresAdmin: true, } - -var CommandSetManagementRoom = &FullHandler{ - Func: func(ce *Event) { - if ce.User.ManagementRoom == ce.RoomID { - ce.Reply("This room is already your management room") - return - } else if ce.Portal != nil { - ce.Reply("This is a portal room: you can't set this as your management room") - return - } - members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) - if err != nil { - ce.Log.Err(err).Msg("Failed to get room members to check if room can be a management room") - ce.Reply("Failed to get room members") - return - } - _, hasBot := members[ce.Bot.GetMXID()] - if !hasBot { - // This reply will probably fail, but whatever - ce.Reply("The bridge bot must be in the room to set it as your management room") - return - } else if len(members) != 2 { - ce.Reply("Your management room must not have any members other than you and the bridge bot") - return - } - ce.User.ManagementRoom = ce.RoomID - err = ce.User.Save(ce.Ctx) - if err != nil { - ce.Log.Err(err).Msg("Failed to save management room") - ce.Reply("Failed to save management room") - } else { - ce.Reply("Management room updated") - } - }, - Name: "set-management-room", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Mark this room as your management room", - }, -} diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go index 1cae98fe..d00697ee 100644 --- a/bridgev2/commands/debug.go +++ b/bridgev2/commands/debug.go @@ -7,13 +7,10 @@ package commands import ( - "encoding/json" "strings" - "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" ) var CommandRegisterPush = &FullHandler{ @@ -60,66 +57,4 @@ var CommandRegisterPush = &FullHandler{ }, RequiresAdmin: true, RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI], -} - -var CommandSendAccountData = &FullHandler{ - Func: func(ce *Event) { - if len(ce.Args) < 2 { - ce.Reply("Usage: `$cmdprefix debug-account-data ") - return - } - var content event.Content - evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType} - ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0])) - err := json.Unmarshal([]byte(ce.RawArgs), &content) - if err != nil { - ce.Reply("Failed to parse JSON: %v", err) - return - } - err = content.ParseRaw(evtType) - if err != nil { - ce.Reply("Failed to deserialize content: %v", err) - return - } - res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{ - Sender: ce.User.MXID, - Type: evtType, - Timestamp: time.Now().UnixMilli(), - RoomID: ce.RoomID, - Content: content, - }) - ce.Reply("Result: %+v", res) - }, - Name: "debug-account-data", - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Send a room account data event to the bridge", - Args: "<_type_> <_content_>", - }, - RequiresAdmin: true, - RequiresPortal: true, - RequiresLogin: true, -} - -var CommandResetNetwork = &FullHandler{ - Func: func(ce *Event) { - if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") { - nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork) - if ok { - nrn.ResetHTTPTransport() - } else { - ce.Reply("Network connector does not support resetting HTTP transport") - } - } - ce.Bridge.ResetNetworkConnections() - ce.React("✅️") - }, - Name: "debug-reset-network", - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Reset network connections to the remote network", - Args: "[--reset-transport]", - }, - RequiresAdmin: true, } diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index 88ba9698..258ae2f0 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "strings" - "time" "github.com/rs/zerolog" @@ -24,21 +23,20 @@ import ( // Event stores all data which might be used to handle commands type Event struct { - Bot bridgev2.MatrixAPI - Bridge *bridgev2.Bridge - Portal *bridgev2.Portal - Processor *Processor - Handler MinimalCommandHandler - RoomID id.RoomID - OrigRoomID id.RoomID - EventID id.EventID - User *bridgev2.User - Command string - Args []string - RawArgs string - ReplyTo id.EventID - Ctx context.Context - Log *zerolog.Logger + Bot bridgev2.MatrixAPI + Bridge *bridgev2.Bridge + Portal *bridgev2.Portal + Processor *Processor + Handler MinimalCommandHandler + RoomID id.RoomID + EventID id.EventID + User *bridgev2.User + Command string + Args []string + RawArgs string + ReplyTo id.EventID + Ctx context.Context + Log *zerolog.Logger MessageStatus *bridgev2.MessageStatus } @@ -57,15 +55,15 @@ func (ce *Event) Reply(msg string, args ...any) { func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice - _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventMessage, &event.Content{Parsed: &content}, nil) + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: &content}, nil) if err != nil { - ce.Log.Err(err).Msg("Failed to reply to command") + ce.Log.Err(err).Msgf("Failed to reply to command") } } // React sends a reaction to the command. func (ce *Event) React(key string) { - _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventReaction, &event.Content{ + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ Type: event.RelAnnotation, @@ -75,26 +73,27 @@ func (ce *Event) React(key string) { }, }, nil) if err != nil { - ce.Log.Err(err).Msg("Failed to react to command") + ce.Log.Err(err).Msgf("Failed to react to command") } } // Redact redacts the command. func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventRedaction, &event.Content{ + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: ce.EventID, }, }, nil) if err != nil { - ce.Log.Err(err).Msg("Failed to redact command") + ce.Log.Err(err).Msgf("Failed to redact command") } } // MarkRead marks the command event as read. func (ce *Event) MarkRead() { - err := ce.Bot.MarkRead(ce.Ctx, ce.RoomID, ce.EventID, time.Now()) - if err != nil { - ce.Log.Err(err).Msg("Failed to mark command as read") - } + // TODO + //err := ce.Bot.SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) + //if err != nil { + // ce.Log.Err(err).Msgf("Failed to mark command as read") + //} } diff --git a/bridgev2/commands/handler.go b/bridgev2/commands/handler.go index 672c81dc..c1daf1af 100644 --- a/bridgev2/commands/handler.go +++ b/bridgev2/commands/handler.go @@ -7,7 +7,6 @@ package commands import ( - "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" ) @@ -38,18 +37,6 @@ type AliasedCommandHandler interface { GetAliases() []string } -func NetworkAPIImplements[T bridgev2.NetworkAPI](val bridgev2.NetworkAPI) bool { - _, ok := val.(T) - return ok -} - -func NetworkConnectorImplements[T bridgev2.NetworkConnector](val bridgev2.NetworkConnector) bool { - _, ok := val.(T) - return ok -} - -type ImplementationChecker[T any] func(val T) bool - type FullHandler struct { Func func(*Event) @@ -62,9 +49,6 @@ type FullHandler struct { RequiresLogin bool RequiresEventLevel event.Type RequiresLoginPermission bool - - NetworkAPI ImplementationChecker[bridgev2.NetworkAPI] - NetworkConnector ImplementationChecker[bridgev2.NetworkConnector] } func (fh *FullHandler) GetHelp() HelpMeta { @@ -80,15 +64,9 @@ func (fh *FullHandler) GetAliases() []string { return fh.Aliases } -func (fh *FullHandler) ImplementationsFulfilled(ce *Event) bool { - // TODO add dedicated method to get an empty NetworkAPI instead of getting default login - client := ce.User.GetDefaultLogin() - return (fh.NetworkAPI == nil || client == nil || fh.NetworkAPI(client.Client)) && - (fh.NetworkConnector == nil || fh.NetworkConnector(ce.Bridge.Network)) -} - func (fh *FullHandler) ShowInHelp(ce *Event) bool { - return fh.ImplementationsFulfilled(ce) && (!fh.RequiresAdmin || ce.User.Permissions.Admin) + return true + //return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin } func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 96d62d3e..df94c6ba 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -10,18 +10,16 @@ import ( "context" "encoding/json" "fmt" - "html" "net/url" "regexp" - "slices" "strings" "github.com/skip2/go-qrcode" "go.mau.fi/util/curl" + "golang.org/x/net/html" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -37,17 +35,6 @@ var CommandLogin = &FullHandler{ RequiresLoginPermission: true, } -var CommandRelogin = &FullHandler{ - Func: fnLogin, - Name: "relogin", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Re-authenticate an existing login", - Args: "<_login ID_> [_flow ID_]", - }, - RequiresLoginPermission: true, -} - func formatFlowsReply(flows []bridgev2.LoginFlow) string { var buf strings.Builder for _, flow := range flows { @@ -57,28 +44,6 @@ func formatFlowsReply(flows []bridgev2.LoginFlow) string { } func fnLogin(ce *Event) { - var reauth *bridgev2.UserLogin - if ce.Command == "relogin" { - if len(ce.Args) == 0 { - ce.Reply("Usage: `$cmdprefix relogin [_flow ID_]`\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) - return - } - reauth = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) - if reauth == nil { - ce.Reply("Login `%s` not found", ce.Args[0]) - return - } - ce.Args = ce.Args[1:] - } - if reauth == nil && ce.User.HasTooManyLogins() { - ce.Reply( - "You have reached the maximum number of logins (%d). "+ - "Please logout from an existing login before creating a new one. "+ - "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.", - ce.User.Permissions.MaxLogins, - ) - return - } flows := ce.Bridge.Network.GetLoginFlows() var chosenFlowID string if len(ce.Args) > 0 { @@ -91,17 +56,13 @@ func fnLogin(ce *Event) { } } if chosenFlowID == "" { - ce.Reply("Invalid login flow `%s`. Available options:\n\n%s", inputFlowID, formatFlowsReply(flows)) + ce.Reply("Invalid login flow `%s`. Available options:\n\n%s", ce.Args[0], formatFlowsReply(flows)) return } } else if len(flows) == 1 { chosenFlowID = flows[0].ID } else { - if reauth != nil { - ce.Reply("Please specify a login flow, e.g. `relogin %s %s`.\n\n%s", reauth.ID, flows[0].ID, formatFlowsReply(flows)) - } else { - ce.Reply("Please specify a login flow, e.g. `login %s`.\n\n%s", flows[0].ID, formatFlowsReply(flows)) - } + ce.Reply("Please specify a login flow, e.g. `login %s`.\n\n%s", flows[0].ID, formatFlowsReply(flows)) return } @@ -110,22 +71,15 @@ func fnLogin(ce *Event) { ce.Reply("Failed to prepare login process: %v", err) return } - overridable, ok := login.(bridgev2.LoginProcessWithOverride) - var nextStep *bridgev2.LoginStep - if ok && reauth != nil { - nextStep, err = overridable.StartWithOverride(ce.Ctx, reauth) - } else { - nextStep, err = login.Start(ce.Ctx) - } + nextStep, err := login.Start(ce.Ctx) if err != nil { ce.Reply("Failed to start login: %v", err) return } - ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process") nextStep = checkLoginCommandDirectParams(ce, login, nextStep) if nextStep != nil { - doLoginStep(ce, login, nextStep, reauth) + doLoginStep(ce, login, nextStep) } } @@ -150,7 +104,6 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS return nil } input := make(map[string]string) - var shouldRedact bool for i, param := range nextStep.UserInputParams.Fields { param.FillDefaultValidate() input[param.ID], err = param.Validate(ce.Args[i]) @@ -158,12 +111,6 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS ce.Reply("Invalid value for %s: %v", param.Name, err) return nil } - if param.Type == bridgev2.LoginInputFieldTypePassword || param.Type == bridgev2.LoginInputFieldTypeToken { - shouldRedact = true - } - } - if shouldRedact { - ce.Redact() } nextStep, err = login.(bridgev2.LoginProcessUserInput).SubmitUserInput(ce.Ctx, input) case bridgev2.LoginStepTypeCookies: @@ -173,14 +120,12 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS } input := make(map[string]string) for i, param := range nextStep.CookiesParams.Fields { - val := maybeURLDecodeCookie(ce.Args[i], ¶m) - if match, _ := regexp.MatchString(param.Pattern, val); !match { - ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", param.ID, val, param.Pattern) + if match, _ := regexp.MatchString(param.Pattern, ce.Args[i]); !match { + ce.Reply("Invalid value for %s: doesn't match regex `%s`", param.ID, param.Pattern) return nil } - input[param.ID] = val + input[param.ID] = ce.Args[i] } - ce.Redact() nextStep, err = login.(bridgev2.LoginProcessCookies).SubmitCookies(ce.Ctx, input) } if err != nil { @@ -195,19 +140,15 @@ type userInputLoginCommandState struct { Login bridgev2.LoginProcessUserInput Data map[string]string RemainingFields []bridgev2.LoginInputDataField - Override *bridgev2.UserLogin } func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { field := uilcs.RemainingFields[0] - parts := []string{fmt.Sprintf("Please enter your %s", field.Name)} if field.Description != "" { - parts = append(parts, field.Description) + ce.Reply("Please enter your %s\n%s", field.Name, field.Description) + } else { + ce.Reply("Please enter your %s", field.Name) } - if len(field.Options) > 0 { - parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `"))) - } - ce.Reply(strings.Join(parts, "\n")) StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(uilcs.submitNext), Action: "Login", @@ -219,7 +160,7 @@ func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { func (uilcs *userInputLoginCommandState) submitNext(ce *Event) { field := uilcs.RemainingFields[0] field.FillDefaultValidate() - if field.Type == bridgev2.LoginInputFieldTypePassword || field.Type == bridgev2.LoginInputFieldTypeToken { + if field.Type == bridgev2.LoginInputFieldTypePassword { ce.Redact() } var err error @@ -236,7 +177,7 @@ func (uilcs *userInputLoginCommandState) submitNext(ce *Event) { if nextStep, err := uilcs.Login.SubmitUserInput(ce.Ctx, uilcs.Data); err != nil { ce.Reply("Failed to submit input: %v", err) } else { - doLoginStep(ce, uilcs.Login, nextStep, uilcs.Override) + doLoginStep(ce, uilcs.Login, nextStep) } } @@ -252,19 +193,14 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return fmt.Errorf("failed to upload image: %w", err) } content := &event.MessageEventContent{ - MsgType: event.MsgImage, - FileName: "qr.png", - URL: qrMXC, - File: qrFile, + MsgType: event.MsgImage, + FileName: "qr.png", + URL: qrMXC, + File: qrFile, + Body: qr, Format: event.FormatHTML, FormattedBody: fmt.Sprintf("
%s
", html.EscapeString(qr)), - Info: &event.FileInfo{ - MimeType: "image/png", - Width: qrSizePx, - Height: qrSizePx, - Size: len(qrData), - }, } if *prevEventID != "" { content.SetEdit(*prevEventID) @@ -279,55 +215,18 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return nil } -func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error { - for _, att := range atts { - if att.FileName == "" { - return fmt.Errorf("missing attachment filename") - } - mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType) - if err != nil { - return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err) - } - content := &event.MessageEventContent{ - MsgType: att.Type, - FileName: att.FileName, - URL: mxc, - File: file, - Info: &event.FileInfo{ - MimeType: att.Info.MimeType, - Width: att.Info.Width, - Height: att.Info.Height, - Size: att.Info.Size, - }, - Body: att.FileName, - } - _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) - if err != nil { - return nil - } - } - return nil -} - type contextKey int const ( contextKeyPrevEventID contextKey = iota ) -func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { +func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep) { prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID) if !ok { prevEvent = new(id.EventID) ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent) } - cancelCtx, cancelFunc := context.WithCancel(ce.Ctx) - defer cancelFunc() - StoreCommandState(ce.User, &CommandState{ - Action: "Login", - Cancel: cancelFunc, - }) - defer StoreCommandState(ce.User, nil) switch step.DisplayAndWaitParams.Type { case bridgev2.LoginDisplayTypeQR: err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent) @@ -347,7 +246,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, login.Cancel() return } - nextStep, err := login.Wait(cancelCtx) + nextStep, err := login.Wait(ce.Ctx) // Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited) if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) { _, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ @@ -361,17 +260,15 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, ce.Reply("Login failed: %v", err) return } - doLoginStep(ce, login, nextStep, override) + doLoginStep(ce, login, nextStep) } type cookieLoginCommandState struct { - Login bridgev2.LoginProcessCookies - Data *bridgev2.LoginCookiesParams - Override *bridgev2.UserLogin + Login bridgev2.LoginProcessCookies + Data *bridgev2.LoginCookiesParams } func (clcs *cookieLoginCommandState) prompt(ce *Event) { - ce.Reply("Login URL: <%s>", clcs.Data.URL) StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(clcs.submit), Action: "Login", @@ -392,7 +289,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { } reqCookies := make(map[string]string) for _, cookie := range parsed.Cookies() { - reqCookies[cookie.Name], err = url.PathUnescape(cookie.Value) + reqCookies[cookie.Name], err = url.QueryUnescape(cookie.Value) if err != nil { ce.Reply("Failed to parse cookie %s: %v", cookie.Name, err) return @@ -451,12 +348,6 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { ce.Reply("Failed to parse input as JSON: %v", err) return } - for _, field := range clcs.Data.Fields { - val, ok := cookiesInput[field.ID] - if ok { - cookiesInput[field.ID] = maybeURLDecodeCookie(val, &field) - } - } } var missingKeys []string for _, field := range clcs.Data.Fields { @@ -465,7 +356,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { missingKeys = append(missingKeys, field.ID) } if match, _ := regexp.MatchString(field.Pattern, val); !match { - ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", field.ID, val, field.Pattern) + ce.Reply("Invalid value for %s: doesn't match regex `%s`", field.ID, field.Pattern) return } } @@ -479,63 +370,30 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { ce.Reply("Login failed: %v", err) return } - doLoginStep(ce, clcs.Login, nextStep, clcs.Override) + doLoginStep(ce, clcs.Login, nextStep) } -func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { - if val == "" { - return val - } - isCookie := slices.ContainsFunc(field.Sources, func(src bridgev2.LoginCookieFieldSource) bool { - return src.Type == bridgev2.LoginCookieTypeCookie - }) - if !isCookie { - return val - } - decoded, err := url.PathUnescape(val) - if err != nil { - return val - } - return decoded -} - -func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { - ce.Log.Debug().Any("next_step", step).Msg("Got next login step") +func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep) { if step.Instructions != "" { ce.Reply(step.Instructions) } switch step.Type { case bridgev2.LoginStepTypeDisplayAndWait: - doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step, override) + doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step) case bridgev2.LoginStepTypeCookies: (&cookieLoginCommandState{ - Login: login.(bridgev2.LoginProcessCookies), - Data: step.CookiesParams, - Override: override, + Login: login.(bridgev2.LoginProcessCookies), + Data: step.CookiesParams, }).prompt(ce) case bridgev2.LoginStepTypeUserInput: - err := sendUserInputAttachments(ce, step.UserInputParams.Attachments) - if err != nil { - ce.Reply("Failed to send attachments: %v", err) - } (&userInputLoginCommandState{ Login: login.(bridgev2.LoginProcessUserInput), RemainingFields: step.UserInputParams.Fields, Data: make(map[string]string), - Override: override, }).promptNext(ce) case bridgev2.LoginStepTypeComplete: - if override != nil && override.ID != step.CompleteParams.UserLoginID { - ce.Log.Info(). - Str("old_login_id", string(override.ID)). - Str("new_login_id", string(step.CompleteParams.UserLoginID)). - Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login") - override.Delete(ce.Ctx, status.BridgeState{ - StateEvent: status.StateLoggedOut, - Reason: "LOGIN_OVERRIDDEN", - }, bridgev2.DeleteOpts{LogoutRemote: true}) - } + // Nothing to do other than instructions default: panic(fmt.Errorf("unknown login step type %q", step.Type)) } diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 391c3685..49769514 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -17,7 +17,8 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/status" + + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -41,12 +42,10 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandSendAccountData, CommandResetNetwork, - CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, - CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, + CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, + CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute, - CommandSudo, CommandDoIn, + CommandResolveIdentifier, CommandStartChat, CommandSearch, ) return proc } @@ -73,18 +72,16 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. Step: status.MsgStepCommand, Status: event.MessageStatusSuccess, } - logCopy := zerolog.Ctx(ctx).With().Logger() - log := &logCopy defer func() { statusInfo := &bridgev2.MessageStatusEventInfo{ - RoomID: roomID, - SourceEventID: eventID, - EventType: event.EventMessage, - Sender: user.MXID, + RoomID: roomID, + EventID: eventID, + EventType: event.EventMessage, + Sender: user.MXID, } err := recover() if err != nil { - logEvt := log.Error(). + logEvt := zerolog.Ctx(ctx).Error(). Bytes(zerolog.ErrorStackFieldName, debug.Stack()) if realErr, ok := err.(error); ok { logEvt = logEvt.Err(realErr) @@ -111,36 +108,29 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") portal, err := proc.bridge.GetPortalByMXID(ctx, roomID) if err != nil { - log.Err(err).Msg("Failed to get portal") // :( } ce := &Event{ - Bot: proc.bridge.Bot, - Bridge: proc.bridge, - Portal: portal, - Processor: proc, - RoomID: roomID, - OrigRoomID: roomID, - EventID: eventID, - User: user, - Command: command, - Args: args[1:], - RawArgs: rawArgs, - ReplyTo: replyTo, - Ctx: ctx, - Log: log, + Bot: proc.bridge.Bot, + Bridge: proc.bridge, + Portal: portal, + Processor: proc, + RoomID: roomID, + EventID: eventID, + User: user, + Command: command, + Args: args[1:], + RawArgs: rawArgs, + ReplyTo: replyTo, + Ctx: ctx, MessageStatus: ms, } - proc.handleCommand(ctx, ce, message, args) -} -func (proc *Processor) handleCommand(ctx context.Context, ce *Event, origMessage string, origArgs []string) { realCommand, ok := proc.aliases[ce.Command] if !ok { realCommand = ce.Command } - log := zerolog.Ctx(ctx) var handler MinimalCommandHandler handler, ok = proc.handlers[realCommand] @@ -148,22 +138,23 @@ func (proc *Processor) handleCommand(ctx context.Context, ce *Event, origMessage state := LoadCommandState(ce.User) if state != nil && state.Next != nil { ce.Command = "" - ce.RawArgs = origMessage - ce.Args = origArgs + ce.RawArgs = message + ce.Args = args ce.Handler = state.Next - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str("action", state.Action) - }) + log := zerolog.Ctx(ctx).With().Str("action", state.Action).Logger() + ce.Log = &log + ce.Ctx = log.WithContext(ctx) log.Debug().Msg("Received reply to command state") state.Next.Run(ce) } else { - zerolog.Ctx(ctx).Debug().Str("mx_command", ce.Command).Msg("Received unknown command") + zerolog.Ctx(ctx).Debug().Str("mx_command", command).Msg("Received unknown command") ce.Reply("Unknown command, use the `help` command for help.") } } else { - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str("mx_command", ce.Command) - }) + log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() + ctx = log.WithContext(ctx) + ce.Log = &log + ce.Ctx = ctx log.Debug().Msg("Received command") ce.Handler = handler handler.Run(ce) diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index 94c19739..af756c87 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) { } onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly var relay *bridgev2.UserLogin - if len(ce.Args) == 0 && ce.Portal.Receiver == "" { + if len(ce.Args) == 0 { relay = ce.User.GetDefaultLogin() isLoggedIn := relay != nil if onlySetDefaultRelays { @@ -73,19 +73,9 @@ func fnSetRelay(ce *Event) { } } } else { - var targetID networkid.UserLoginID - if ce.Portal.Receiver != "" { - targetID = ce.Portal.Receiver - if len(ce.Args) > 0 && ce.Args[0] != string(targetID) { - ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID) - return - } - } else { - targetID = networkid.UserLoginID(ce.Args[0]) - } - relay = ce.Bridge.GetCachedUserLoginByID(targetID) + relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) if relay == nil { - ce.Reply("User login with ID `%s` not found", targetID) + ce.Reply("User login with ID `%s` not found", ce.Args[0]) return } else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) { // All good diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index c7b05a6e..24c8a488 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,21 +8,14 @@ package commands import ( "context" - "errors" "fmt" - "html" - "maps" - "slices" "strings" "time" - "github.com/rs/zerolog" + "golang.org/x/net/html" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/provisionutil" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) @@ -35,60 +28,22 @@ var CommandResolveIdentifier = &FullHandler{ Args: "[_login ID_] <_identifier_>", }, RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], -} - -var CommandSyncChat = &FullHandler{ - Func: func(ce *Event) { - login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false) - if err != nil { - ce.Log.Err(err).Msg("Failed to find login for sync") - ce.Reply("Failed to find login: %v", err) - return - } else if login == nil { - ce.Reply("No login found for sync") - return - } - info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal) - if err != nil { - ce.Log.Err(err).Msg("Failed to get chat info for sync") - ce.Reply("Failed to get chat info: %v", err) - return - } - ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{}) - ce.React("✅️") - }, - Name: "sync-portal", - Help: HelpMeta{ - Section: HelpSectionChats, - Description: "Sync the current portal room", - }, - RequiresPortal: true, - RequiresLogin: true, } var CommandStartChat = &FullHandler{ - Func: fnResolveIdentifier, - Name: "start-chat", - Aliases: []string{"pm"}, + Func: fnResolveIdentifier, + Name: "start-chat", Help: HelpMeta{ Section: HelpSectionChats, Description: "Start a direct chat with the given user", Args: "[_login ID_] <_identifier_>", }, RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } -func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { - var remainingArgs []string - if len(ce.Args) > 1 { - remainingArgs = ce.Args[1:] - } - var login *bridgev2.UserLogin - if len(ce.Args) > 0 { - login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) - } +func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { + remainingArgs := ce.Args[1:] + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) if login == nil || login.UserMXID != ce.User.MXID { remainingArgs = ce.Args login = ce.User.GetDefaultLogin() @@ -100,13 +55,24 @@ func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (* return login, api, remainingArgs } -func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string { - if resp.MXID != "" { - return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL()) - } else if resp.Name != "" { - return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name) +func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string { + var targetName string + var targetMXID id.UserID + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(ctx, resp.UserInfo) + } + targetName = resp.Ghost.Name + targetMXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + targetName = *resp.UserInfo.Name + } + if targetMXID != "" { + return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL()) + } else if targetName != "" { + return fmt.Sprintf("`%s` / %s", resp.UserID, targetName) } else { - return fmt.Sprintf("`%s`", resp.ID) + return fmt.Sprintf("`%s`", resp.UserID) } } @@ -119,137 +85,65 @@ func fnResolveIdentifier(ce *Event) { if api == nil { return } - allLogins := ce.User.GetUserLogins() - createChat := ce.Command == "start-chat" || ce.Command == "pm" + createChat := ce.Command == "start-chat" identifier := strings.Join(identifierParts, " ") - resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat) - for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ { - resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat) - } + resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat) if err != nil { + ce.Log.Err(err).Msg("Failed to resolve identifier") ce.Reply("Failed to resolve identifier: %v", err) return } else if resp == nil { ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true) return } - formattedName := formatResolveIdentifierResult(resp) + formattedName := formatResolveIdentifierResult(ce.Ctx, resp) if createChat { - name := resp.Portal.Name - if name == "" { - name = resp.Portal.MXID.String() + if resp.Chat == nil { + ce.Reply("Interface error: network connector did not return chat for create chat request") + return } - if !resp.JustCreated { - ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) + portal := resp.Chat.Portal + if portal == nil { + portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey) + if err != nil { + ce.Log.Err(err).Msg("Failed to get portal") + ce.Reply("Failed to get portal: %v", err) + return + } + } + if resp.Chat.PortalInfo == nil { + resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal) + if err != nil { + ce.Log.Err(err).Msg("Failed to get portal info") + ce.Reply("Failed to get portal info: %v", err) + return + } + } + if portal.MXID != "" { + name := portal.Name + if name == "" { + name = portal.MXID.String() + } + portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{}) + ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) } else { - ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) + err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo) + if err != nil { + ce.Log.Err(err).Msg("Failed to create room") + ce.Reply("Failed to create room: %v", err) + return + } + name := portal.Name + if name == "" { + name = portal.MXID.String() + } + ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) } } else { ce.Reply("Found %s", formattedName) } } -var CommandCreateGroup = &FullHandler{ - Func: fnCreateGroup, - Name: "create-group", - Aliases: []string{"create"}, - Help: HelpMeta{ - Section: HelpSectionChats, - Description: "Create a new group chat for the current Matrix room", - Args: "[_group type_]", - }, - RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI], -} - -func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) { - evt, err := provider.GetStateEvent(ctx, roomID, evtType, "") - if err != nil { - zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation") - } else if evt != nil { - content, _ = evt.Content.Parsed.(T) - } - return -} - -func fnCreateGroup(ce *Event) { - ce.Bridge.Matrix.GetCapabilities() - login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group") - if api == nil { - return - } - stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState) - if !ok { - ce.Reply("Matrix connector doesn't support fetching room state") - return - } - members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) - if err != nil { - ce.Log.Err(err).Msg("Failed to get room members for group creation") - ce.Reply("Failed to get room members: %v", err) - return - } - caps := ce.Bridge.Network.GetCapabilities() - params := &bridgev2.GroupCreateParams{ - Username: "", - Participants: make([]networkid.UserID, 0, len(members)-2), - Parent: nil, // TODO check space parent event - Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider), - Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider), - Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider), - Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider), - RoomID: ce.RoomID, - } - for userID, member := range members { - if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() { - continue - } - if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok { - params.Participants = append(params.Participants, parsedUserID) - } else if !ce.Bridge.Config.SplitPortals { - if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil { - ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member") - } else if user != nil { - // TODO add user logins to participants - //for _, login := range user.GetUserLogins() { - // params.Participants = append(params.Participants, login.GetUserID()) - //} - } - } - } - - if len(caps.Provisioning.GroupCreation) == 0 { - ce.Reply("No group creation types defined in network capabilities") - return - } else if len(remainingArgs) > 0 { - params.Type = remainingArgs[0] - } else if len(caps.Provisioning.GroupCreation) == 1 { - for params.Type = range caps.Provisioning.GroupCreation { - // The loop assigns the variable we want - } - } else { - types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `") - ce.Reply("Please specify type of group to create: `%s`", types) - return - } - resp, err := provisionutil.CreateGroup(ce.Ctx, login, params) - if err != nil { - ce.Reply("Failed to create group: %v", err) - return - } - var postfix string - if len(resp.FailedParticipants) > 0 { - failedParticipantsStrings := make([]string, len(resp.FailedParticipants)) - i := 0 - for participantID, meta := range resp.FailedParticipants { - failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason) - i++ - } - postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n") - } - ce.Reply("Successfully created group `%s`%s", resp.ID, postfix) -} - var CommandSearch = &FullHandler{ Func: fnSearch, Name: "search", @@ -259,7 +153,6 @@ var CommandSearch = &FullHandler{ Args: "<_query_>", }, RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.UserSearchingNetworkAPI], } func fnSearch(ce *Event) { @@ -267,67 +160,35 @@ func fnSearch(ce *Event) { ce.Reply("Usage: `$cmdprefix search `") return } - login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") + _, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") if api == nil { return } - resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " ")) + results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " ")) if err != nil { + ce.Log.Err(err).Msg("Failed to search for users") ce.Reply("Failed to search for users: %v", err) return } - resultsString := make([]string, len(resp.Results)) - for i, res := range resp.Results { - formattedName := formatResolveIdentifierResult(res) + resultsString := make([]string, len(results)) + for i, res := range results { + formattedName := formatResolveIdentifierResult(ce.Ctx, res) resultsString[i] = fmt.Sprintf("* %s", formattedName) - if res.Portal != nil && res.Portal.MXID != "" { - portalName := res.Portal.Name - if portalName == "" { - portalName = res.Portal.MXID.String() + if res.Chat != nil { + if res.Chat.Portal == nil { + res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey) + if err != nil { + ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal") + } + } + if res.Chat.Portal != nil { + portalName := res.Chat.Portal.Name + if portalName == "" { + portalName = res.Chat.Portal.MXID.String() + } + resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL()) } - resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL()) } } ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) } - -var CommandMute = &FullHandler{ - Func: fnMute, - Name: "mute", - Aliases: []string{"unmute"}, - Help: HelpMeta{ - Section: HelpSectionChats, - Description: "Mute or unmute a chat on the remote network", - Args: "[duration]", - }, - RequiresPortal: true, - RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI], -} - -func fnMute(ce *Event) { - _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats") - var mutedUntil int64 - if ce.Command == "mute" { - mutedUntil = -1 - if len(ce.Args) > 0 { - duration, err := time.ParseDuration(ce.Args[0]) - if err != nil { - ce.Reply("Invalid duration: %v", err) - return - } - mutedUntil = time.Now().Add(duration).UnixMilli() - } - } - err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{ - MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{ - Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil}, - Portal: ce.Portal, - }, - }) - if err != nil { - ce.Reply("Failed to %s chat: %v", ce.Command, err) - } else { - ce.React("✅️") - } -} diff --git a/bridgev2/commands/sudo.go b/bridgev2/commands/sudo.go deleted file mode 100644 index f05ca1bb..00000000 --- a/bridgev2/commands/sudo.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "strings" - - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var CommandSudo = &FullHandler{ - Func: fnSudo, - Name: "sudo", - Aliases: []string{"doas", "do-as", "runas", "run-as"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Run a command as a different user.", - Args: "[--create] <_user ID_> <_command_> [_args..._]", - }, - RequiresAdmin: true, -} - -func fnSudo(ce *Event) { - forceNonexistentUser := len(ce.Args) > 0 && strings.ToLower(ce.Args[0]) == "--create" - if forceNonexistentUser { - ce.Args = ce.Args[1:] - } - if len(ce.Args) < 2 { - ce.Reply("Usage: `$cmdprefix sudo [--create] [args...]`") - return - } - targetUserID := id.UserID(ce.Args[0]) - if _, _, err := targetUserID.Parse(); err != nil || len(targetUserID) > id.UserIDMaxLength { - ce.Reply("Invalid user ID `%s`", targetUserID) - return - } - var targetUser *bridgev2.User - var err error - if forceNonexistentUser { - targetUser, err = ce.Bridge.GetUserByMXID(ce.Ctx, targetUserID) - } else { - targetUser, err = ce.Bridge.GetExistingUserByMXID(ce.Ctx, targetUserID) - } - if err != nil { - ce.Log.Err(err).Msg("Failed to get user from database") - ce.Reply("Failed to get user") - return - } else if targetUser == nil { - ce.Reply("User not found. Use `--create` if you want to run commands as a user who has never used the bridge.") - return - } - ce.User = targetUser - origArgs := ce.Args[1:] - ce.Command = strings.ToLower(ce.Args[1]) - ce.Args = ce.Args[2:] - ce.RawArgs = strings.Join(ce.Args, " ") - ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) -} - -var CommandDoIn = &FullHandler{ - Func: fnDoIn, - Name: "doin", - Aliases: []string{"do-in", "runin", "run-in"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Run a command in a different room.", - Args: "<_room ID_> <_command_> [_args..._]", - }, -} - -func fnDoIn(ce *Event) { - if len(ce.Args) < 2 { - ce.Reply("Usage: `$cmdprefix doin [args...]`") - return - } - targetRoomID := id.RoomID(ce.Args[0]) - if !ce.User.Permissions.Admin { - memberInfo, err := ce.Bridge.Matrix.GetMemberInfo(ce.Ctx, targetRoomID, ce.User.MXID) - if err != nil { - ce.Log.Err(err).Msg("Failed to check if user is in doin target room") - ce.Reply("Failed to check if you're in the target room") - return - } else if memberInfo == nil || memberInfo.Membership != event.MembershipJoin { - ce.Reply("You must be in the target room to run commands there") - return - } - } - ce.RoomID = targetRoomID - var err error - ce.Portal, err = ce.Bridge.GetPortalByMXID(ce.Ctx, targetRoomID) - if err != nil { - ce.Log.Err(err).Msg("Failed to get target portal") - ce.Reply("Failed to get portal") - return - } - origArgs := ce.Args[1:] - ce.Command = strings.ToLower(ce.Args[1]) - ce.Args = ce.Args[2:] - ce.RawArgs = strings.Join(ce.Args, " ") - ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) -} diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index 1f920640..fed7452d 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -78,11 +78,6 @@ const ( dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11 WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 ` - markBackfillTaskNotDoneQuery = ` - UPDATE backfill_task - SET is_done = false - WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND user_login_id = $4 - ` getNextBackfillQuery = ` SELECT bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, @@ -91,13 +86,6 @@ const ( WHERE bridge_id = $1 AND next_dispatch_min_ts < $2 AND is_done = false AND user_login_id <> '' ORDER BY next_dispatch_min_ts LIMIT 1 ` - getNextBackfillQueryForPortal = ` - SELECT - bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, - cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts - FROM backfill_task - WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND is_done = false AND user_login_id <> '' - ` deleteBackfillQueueQuery = ` DELETE FROM backfill_task WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 @@ -132,18 +120,10 @@ func (btq *BackfillTaskQuery) Update(ctx context.Context, bq *BackfillTask) erro return btq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...) } -func (btq *BackfillTaskQuery) MarkNotDone(ctx context.Context, portalKey networkid.PortalKey, userLoginID networkid.UserLoginID) error { - return btq.Exec(ctx, markBackfillTaskNotDoneQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver, userLoginID) -} - func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error) { return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano()) } -func (btq *BackfillTaskQuery) GetNextForPortal(ctx context.Context, portalKey networkid.PortalKey) (*BackfillTask, error) { - return btq.QueryOne(ctx, getNextBackfillQueryForPortal, btq.BridgeID, portalKey.ID, portalKey.Receiver) -} - func (btq *BackfillTaskQuery) Delete(ctx context.Context, portalKey networkid.PortalKey) error { return btq.Exec(ctx, deleteBackfillQueueQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver) } diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index 05abddf0..aa77a232 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -7,7 +7,13 @@ package database import ( + "encoding/json" + "reflect" + "strings" + "go.mau.fi/util/dbutil" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/networkid" @@ -27,8 +33,6 @@ type Database struct { UserLogin *UserLoginQuery UserPortal *UserPortalQuery BackfillTask *BackfillTaskQuery - KV *KVQuery - PublicMedia *PublicMediaQuery } type MetaMerger interface { @@ -132,16 +136,6 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa return &BackfillTask{} }), }, - KV: &KVQuery{ - BridgeID: bridgeID, - Database: db, - }, - PublicMedia: &PublicMediaQuery{ - BridgeID: bridgeID, - QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia { - return &PublicMedia{} - }), - }, } } @@ -152,3 +146,55 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) panic("bridge ID mismatch") } } + +func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) { + if val, found := m[key]; found { + floatVal, ok := val.(float64) + if ok { + return T(floatVal), true + } + tVal, ok := val.(T) + if ok { + return tVal, true + } + } + return 0, false +} + +func unmarshalMerge(input []byte, data any, extra *map[string]any) error { + err := json.Unmarshal(input, data) + if err != nil { + return err + } + err = json.Unmarshal(input, extra) + if err != nil { + return err + } + if *extra == nil { + *extra = make(map[string]any) + } + return nil +} + +func marshalMerge(data any, extra map[string]any) ([]byte, error) { + if extra == nil { + return json.Marshal(data) + } + merged := make(map[string]any) + maps.Copy(merged, extra) + dataRef := reflect.ValueOf(data).Elem() + dataType := dataRef.Type() + for _, field := range reflect.VisibleFields(dataType) { + parts := strings.Split(field.Tag.Get("json"), ",") + if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" { + continue + } + fieldVal := dataRef.FieldByIndex(field.Index) + if fieldVal.IsZero() { + delete(merged, parts[0]) + } else { + merged[parts[0]] = fieldVal.Interface() + } + } + return json.Marshal(merged) +} diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index df36b205..23db1448 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -12,94 +12,56 @@ import ( "time" "go.mau.fi/util/dbutil" - "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) -// Deprecated: use [event.DisappearingType] -type DisappearingType = event.DisappearingType +// DisappearingType represents the type of a disappearing message timer. +type DisappearingType string -// Deprecated: use constants in event package const ( - DisappearingTypeNone = event.DisappearingTypeNone - DisappearingTypeAfterRead = event.DisappearingTypeAfterRead - DisappearingTypeAfterSend = event.DisappearingTypeAfterSend + DisappearingTypeNone DisappearingType = "" + DisappearingTypeAfterRead DisappearingType = "after_read" + DisappearingTypeAfterSend DisappearingType = "after_send" ) // DisappearingSetting represents a disappearing message timer setting // by combining a type with a timer and an optional start timestamp. type DisappearingSetting struct { - Type event.DisappearingType + Type DisappearingType Timer time.Duration DisappearAt time.Time } -func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting { - if evt == nil || evt.Type == event.DisappearingTypeNone { - return DisappearingSetting{} - } - return DisappearingSetting{ - Type: evt.Type, - Timer: evt.Timer.Duration, - } -} - -func (ds DisappearingSetting) Normalize() DisappearingSetting { - if ds.Type == event.DisappearingTypeNone { - ds.Timer = 0 - } else if ds.Timer == 0 { - ds.Type = event.DisappearingTypeNone - } - return ds -} - -func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting { - ds.DisappearAt = start.Add(ds.Timer) - return ds -} - -func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer { - if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 { - return &event.BeeperDisappearingTimer{} - } - return &event.BeeperDisappearingTimer{ - Type: ds.Type, - Timer: jsontime.MS(ds.Timer), - } -} - type DisappearingMessageQuery struct { BridgeID networkid.BridgeID *dbutil.QueryHelper[*DisappearingMessage] } type DisappearingMessage struct { - BridgeID networkid.BridgeID - RoomID id.RoomID - EventID id.EventID - Timestamp time.Time + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID DisappearingSetting } const ( upsertDisappearingMessageQuery = ` - INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at ` startDisappearingMessagesQuery = ` UPDATE disappearing_message SET disappear_at=$1 + timer - WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4 - RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' + RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at ` getUpcomingDisappearingMessagesQuery = ` - SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at + SELECT bridge_id, mx_room, mxid, type, timer, disappear_at FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 - ORDER BY disappear_at LIMIT $3 + ORDER BY disappear_at ` deleteDisappearingMessageQuery = ` DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2 @@ -111,12 +73,12 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) } -func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano()) +func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) } -func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano(), limit) +func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano()) } func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error { @@ -124,19 +86,17 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even } func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { - var timestamp int64 var disappearAt sql.NullInt64 - err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt) + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt) if err != nil { return nil, err } if disappearAt.Valid { d.DisappearAt = time.Unix(0, disappearAt.Int64) } - d.Timestamp = time.Unix(0, timestamp) return d, nil } func (d *DisappearingMessage) sqlVariables() []any { - return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} + return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} } diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index 16af35ca..c32929ad 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -7,17 +7,12 @@ package database import ( - "bytes" "context" "encoding/hex" - "encoding/json" - "fmt" "go.mau.fi/util/dbutil" - "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) @@ -27,55 +22,6 @@ type GhostQuery struct { *dbutil.QueryHelper[*Ghost] } -type ExtraProfile map[string]json.RawMessage - -func (ep *ExtraProfile) Set(key string, value any) error { - if key == "displayname" || key == "avatar_url" { - return fmt.Errorf("cannot set reserved profile key %q", key) - } - marshaled, err := json.Marshal(value) - if err != nil { - return err - } - if *ep == nil { - *ep = make(ExtraProfile) - } - (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled) - return nil -} - -func (ep *ExtraProfile) With(key string, value any) *ExtraProfile { - exerrors.PanicIfNotNil(ep.Set(key, value)) - return ep -} - -func canonicalizeIfObject(data json.RawMessage) json.RawMessage { - if len(data) > 0 && (data[0] == '{' || data[0] == '[') { - return canonicaljson.CanonicalJSONAssumeValid(data) - } - return data -} - -func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) { - if len(*ep) == 0 { - return - } - if *dest == nil { - *dest = make(ExtraProfile) - } - for key, val := range *ep { - if key == "displayname" || key == "avatar_url" { - continue - } - existing, exists := (*dest)[key] - if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) { - (*dest)[key] = val - changed = true - } - } - return -} - type Ghost struct { BridgeID networkid.BridgeID ID networkid.UserID @@ -89,14 +35,13 @@ type Ghost struct { ContactInfoSet bool IsBot bool Identifiers []string - ExtraProfile ExtraProfile Metadata any } const ( getGhostBaseQuery = ` SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata FROM ghost ` getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` @@ -104,14 +49,13 @@ const ( insertGhostQuery = ` INSERT INTO ghost ( bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ` updateGhostQuery = ` UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, - name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, - identifiers=$11, extra_profile=$12, metadata=$13 + name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12 WHERE bridge_id=$1 AND id=$2 ` ) @@ -142,7 +86,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { &g.BridgeID, &g.ID, &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, &g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, ) if err != nil { return nil, err @@ -172,6 +116,6 @@ func (g *Ghost) sqlVariables() []any { g.BridgeID, g.ID, g.Name, g.AvatarID, avatarHash, g.AvatarMXC, g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, } } diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go deleted file mode 100644 index bca26ed5..00000000 --- a/bridgev2/database/kvstore.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "errors" - - "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/bridgev2/networkid" -) - -type Key string - -const ( - KeySplitPortalsEnabled Key = "split_portals_enabled" - KeyBridgeInfoVersion Key = "bridge_info_version" - KeyEncryptionStateResynced Key = "encryption_state_resynced" - KeyRecoveryKey Key = "recovery_key" -) - -type KVQuery struct { - BridgeID networkid.BridgeID - *dbutil.Database -} - -const ( - getKVQuery = `SELECT value FROM kv_store WHERE bridge_id = $1 AND key = $2` - setKVQuery = ` - INSERT INTO kv_store (bridge_id, key, value) VALUES ($1, $2, $3) - ON CONFLICT (bridge_id, key) DO UPDATE SET value = $3 - ` -) - -func (kvq *KVQuery) Get(ctx context.Context, key Key) string { - var value string - err := kvq.QueryRow(ctx, getKVQuery, kvq.BridgeID, key).Scan(&value) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - zerolog.Ctx(ctx).Err(err).Str("key", string(key)).Msg("Failed to get key from kvstore") - } - return value -} - -func (kvq *KVQuery) Set(ctx context.Context, key Key, value string) { - _, err := kvq.Exec(ctx, setKVQuery, kvq.BridgeID, key, value) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Str("key", string(key)). - Str("value", value). - Msg("Failed to set key in kvstore") - } -} diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 4fd599a8..8173ad05 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -11,12 +11,9 @@ import ( "crypto/sha256" "database/sql" "encoding/base64" - "fmt" "strings" - "sync" "time" - "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" @@ -27,7 +24,6 @@ type MessageQuery struct { BridgeID networkid.BridgeID MetaType MetaTypeCreator *dbutil.QueryHelper[*Message] - chunkDeleteLock sync.Mutex } type Message struct { @@ -37,43 +33,36 @@ type Message struct { PartID networkid.PartID MXID id.EventID - Room networkid.PortalKey - SenderID networkid.UserID - SenderMXID id.UserID - Timestamp time.Time - EditCount int - IsDoublePuppeted bool + Room networkid.PortalKey + SenderID networkid.UserID + SenderMXID id.UserID + Timestamp time.Time + EditCount int ThreadRoot networkid.MessageID ReplyTo networkid.MessageOptionalPartID - SendTxnID networkid.RawTransactionID - Metadata any } const ( getMessageBaseQuery = ` SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, - send_txn_id, metadata + timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata FROM message ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4` getMessagePartByRowIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND rowid=$2` getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` - getMessageByTxnIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND (mxid=$3 OR send_txn_id=$4)` getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1` - getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1` - getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` + getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` + getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` - getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` - getLastNonFakeMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 AND mxid NOT LIKE '~fake:%' ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` countMessagesInPortalQuery = ` SELECT COUNT(*) FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 @@ -82,17 +71,15 @@ const ( insertMessageQuery = ` INSERT INTO message ( bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, - send_txn_id, metadata + timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING rowid ` updateMessageQuery = ` UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, sender_mxid=$8, - timestamp=$9, edit_count=$10, double_puppeted=$11, thread_root_id=$12, reply_to_id=$13, - reply_to_part_id=$14, send_txn_id=$15, metadata=$16 - WHERE bridge_id=$1 AND rowid=$17 + timestamp=$9, edit_count=$10, thread_root_id=$11, reply_to_id=$12, reply_to_part_id=$13, metadata=$14 + WHERE bridge_id=$1 AND rowid=$15 ` deleteAllMessagePartsByIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 @@ -100,10 +87,6 @@ const ( deleteMessagePartByRowIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 ` - deleteMessageChunkQuery = ` - DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5 - ` - getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1` ) func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) { @@ -118,10 +101,6 @@ func (mq *MessageQuery) GetPartByMXID(ctx context.Context, mxid id.EventID) (*Me return mq.QueryOne(ctx, getMessageByMXIDQuery, mq.BridgeID, mxid) } -func (mq *MessageQuery) GetPartByTxnID(ctx context.Context, receiver networkid.UserLoginID, mxid id.EventID, txnID networkid.RawTransactionID) (*Message, error) { - return mq.QueryOne(ctx, getMessageByTxnIDQuery, mq.BridgeID, receiver, mxid, txnID) -} - func (mq *MessageQuery) GetLastPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) { return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, receiver, id) } @@ -146,10 +125,6 @@ func (mq *MessageQuery) GetLastPartAtOrBeforeTime(ctx context.Context, portal ne return mq.QueryOne(ctx, getLastMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) } -func (mq *MessageQuery) GetLastNonFakePartAtOrBeforeTime(ctx context.Context, portal networkid.PortalKey, maxTS time.Time) (*Message, error) { - return mq.QueryOne(ctx, getLastNonFakeMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) -} - func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal networkid.PortalKey, start, end time.Time) ([]*Message, error) { return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) } @@ -166,10 +141,6 @@ func (mq *MessageQuery) GetLastThreadMessage(ctx context.Context, portal network return mq.QueryOne(ctx, getLastMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot) } -func (mq *MessageQuery) GetLastNInPortal(ctx context.Context, portal networkid.PortalKey, n int) ([]*Message, error) { - return mq.QueryMany(ctx, getLastNInPortal, mq.BridgeID, portal.ID, portal.Receiver, n) -} - func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error { ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.ensureHasMetadata(mq.MetaType).sqlVariables()...).Scan(&msg.RowID) @@ -188,85 +159,6 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) } -func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) { - res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) { - err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID) - return -} - -const deleteChunkSize = 100_000 - -func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error { - if mq.GetDB().Dialect != dbutil.SQLite { - return nil - } - log := zerolog.Ctx(ctx).With(). - Str("action", "delete messages in chunks"). - Stringer("portal_key", portal). - Logger() - if !mq.chunkDeleteLock.TryLock() { - log.Warn().Msg("Portal deletion lock is being held, waiting...") - mq.chunkDeleteLock.Lock() - log.Debug().Msg("Acquired portal deletion lock after waiting") - } - defer mq.chunkDeleteLock.Unlock() - total, err := mq.CountMessagesInPortal(ctx, portal) - if err != nil { - return fmt.Errorf("failed to count messages in portal: %w", err) - } else if total < deleteChunkSize/3 { - return nil - } - globalMaxRowID, err := mq.getMaxRowID(ctx) - if err != nil { - return fmt.Errorf("failed to get max row ID: %w", err) - } - log.Debug(). - Int("total_count", total). - Int64("global_max_row_id", globalMaxRowID). - Msg("Portal has lots of messages, deleting in chunks to avoid database locks") - maxRowID := int64(deleteChunkSize) - globalMaxRowID += deleteChunkSize * 1.2 - var dbTimeUsed time.Duration - globalStart := time.Now() - for total > 500 && maxRowID < globalMaxRowID { - start := time.Now() - count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID) - duration := time.Since(start) - dbTimeUsed += duration - if err != nil { - return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err) - } - total -= int(count) - maxRowID += deleteChunkSize - sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond)) - log.Debug(). - Int64("max_row_id", maxRowID). - Int64("deleted_count", count). - Int("remaining_count", total). - Dur("duration", duration). - Dur("sleep_time", sleepTime). - Msg("Deleted chunk of messages") - select { - case <-time.After(sleepTime): - case <-ctx.Done(): - return ctx.Err() - } - } - log.Debug(). - Int("remaining_count", total). - Dur("db_time_used", dbTimeUsed). - Dur("total_duration", time.Since(globalStart)). - Msg("Finished chunked delete of messages in portal") - return nil -} - func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) return @@ -274,28 +166,22 @@ func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 - var threadRootID, replyToID, replyToPartID, sendTxnID sql.NullString - var doublePuppeted sql.NullBool + var threadRootID, replyToID, replyToPartID sql.NullString err := row.Scan( &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, - ×tamp, &m.EditCount, &doublePuppeted, &threadRootID, &replyToID, &replyToPartID, &sendTxnID, - dbutil.JSON{Data: m.Metadata}, + ×tamp, &m.EditCount, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, ) if err != nil { return nil, err } m.Timestamp = time.Unix(0, timestamp) m.ThreadRoot = networkid.MessageID(threadRootID.String) - m.IsDoublePuppeted = doublePuppeted.Valid if replyToID.Valid { m.ReplyTo.MessageID = networkid.MessageID(replyToID.String) if replyToPartID.Valid { m.ReplyTo.PartID = (*networkid.PartID)(&replyToPartID.String) } } - if sendTxnID.Valid { - m.SendTxnID = networkid.RawTransactionID(sendTxnID.String) - } return m, nil } @@ -309,8 +195,7 @@ func (m *Message) ensureHasMetadata(metaType MetaTypeCreator) *Message { func (m *Message) sqlVariables() []any { return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID, - m.Timestamp.UnixNano(), m.EditCount, m.IsDoublePuppeted, dbutil.StrPtr(m.ThreadRoot), - dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.StrPtr(m.SendTxnID), + m.Timestamp.UnixNano(), m.EditCount, dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.JSON{Data: m.Metadata}, } } @@ -320,9 +205,6 @@ func (m *Message) updateSQLVariables() []any { } const FakeMXIDPrefix = "~fake:" -const TxnMXIDPrefix = "~txn:" -const NetworkTxnMXIDPrefix = TxnMXIDPrefix + "network:" -const RandomTxnMXIDPrefix = TxnMXIDPrefix + "random:" func (m *Message) SetFakeMXID() { hash := sha256.Sum256([]byte(m.ID)) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 0e6be286..bc1f2658 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -16,7 +16,6 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -35,53 +34,35 @@ type PortalQuery struct { *dbutil.QueryHelper[*Portal] } -type CapStateFlags uint32 - -func (csf CapStateFlags) Has(flag CapStateFlags) bool { - return csf&flag != 0 -} - -const ( - CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota -) - -type CapabilityState struct { - Source networkid.UserLoginID `json:"source"` - ID string `json:"id"` - Flags CapStateFlags `json:"flags"` -} - type Portal struct { BridgeID networkid.BridgeID networkid.PortalKey MXID id.RoomID - ParentKey networkid.PortalKey - RelayLoginID networkid.UserLoginID - OtherUserID networkid.UserID - Name string - Topic string - AvatarID networkid.AvatarID - AvatarHash [32]byte - AvatarMXC id.ContentURIString - NameSet bool - TopicSet bool - AvatarSet bool - NameIsCustom bool - InSpace bool - MessageRequest bool - RoomType RoomType - Disappear DisappearingSetting - CapState CapabilityState - Metadata any + ParentID networkid.PortalID + RelayLoginID networkid.UserLoginID + OtherUserID networkid.UserID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + NameIsCustom bool + InSpace bool + RoomType RoomType + Disappear DisappearingSetting + Metadata any } const ( getPortalBaseQuery = ` - SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, + SELECT bridge_id, id, receiver, mxid, parent_id, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, topic_set, avatar_set, name_is_custom, in_space, message_request, - room_type, disappear_type, disappear_timer, cap_state, + name_set, topic_set, avatar_set, name_is_custom, in_space, + room_type, disappear_type, disappear_timer, metadata FROM portal ` @@ -89,71 +70,38 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` - getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` - getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` - getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` + getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` findPortalReceiverQuery = `SELECT id, receiver FROM portal WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='') LIMIT 1` insertPortalQuery = ` INSERT INTO portal ( bridge_id, id, receiver, mxid, - parent_id, parent_receiver, relay_login_id, other_user_id, + parent_id, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, topic_set, name_is_custom, in_space, message_request, - room_type, disappear_type, disappear_timer, cap_state, + name_set, avatar_set, topic_set, name_is_custom, in_space, + room_type, disappear_type, disappear_timer, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, - CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END + $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, + CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` updatePortalQuery = ` UPDATE portal - SET mxid=$4, parent_id=$5, parent_receiver=$6, - relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, - other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, - name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19, - room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24 + SET mxid=$4, parent_id=$5, relay_login_id=cast($6 AS TEXT), relay_bridge_id=CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, + other_user_id=$7, name=$8, topic=$9, avatar_id=$10, avatar_hash=$11, avatar_mxc=$12, + name_set=$13, avatar_set=$14, topic_set=$15, name_is_custom=$16, in_space=$17, + room_type=$18, disappear_type=$19, disappear_timer=$20, metadata=$21 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` DELETE FROM portal WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` - reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` - migrateToSplitPortalsQuery = ` - UPDATE portal - SET receiver=new_receiver - FROM ( - SELECT bridge_id, id, COALESCE(( - SELECT login_id - FROM user_portal - WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' - LIMIT 1 - ), ( - SELECT login_id - FROM user_portal - WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id - LIMIT 1 - ), ( - SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 - ), '') AS new_receiver - FROM portal - WHERE receiver='' AND bridge_id=$1 - ) updates - WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS ( - SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver - ) - ` - fixParentsAfterSplitPortalMigrationQuery = ` - UPDATE portal - SET parent_receiver=receiver - WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>'' - AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver); - ` + reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` ) func (pq *PortalQuery) GetByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { @@ -180,10 +128,6 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID) } -func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) { - return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID) -} - func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID) } @@ -192,12 +136,8 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid. return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) } -func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { - return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID) -} - -func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) { - return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver) +func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.PortalID) ([]*Portal, error) { + return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID) } func (pq *PortalQuery) ReID(ctx context.Context, oldID, newID networkid.PortalKey) error { @@ -218,33 +158,17 @@ func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) erro return pq.Exec(ctx, deletePortalQuery, pq.BridgeID, key.ID, key.Receiver) } -func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error) { - res, err := pq.GetDB().Exec(ctx, migrateToSplitPortalsQuery, pq.BridgeID) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) { - res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { - var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString + var mxid, parentID, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64 var avatarHash string err := row.Scan( &p.BridgeID, &p.ID, &p.Receiver, &mxid, - &parentID, &parentReceiver, &relayLoginID, &otherUserID, + &parentID, &relayLoginID, &otherUserID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, - &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.RoomType, &disappearType, &disappearTimer, - dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata}, + dbutil.JSON{Data: p.Metadata}, ) if err != nil { return nil, err @@ -257,18 +181,13 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } if disappearType.Valid { p.Disappear = DisappearingSetting{ - Type: event.DisappearingType(disappearType.String), + Type: DisappearingType(disappearType.String), Timer: time.Duration(disappearTimer.Int64), } } p.MXID = id.RoomID(mxid.String) p.OtherUserID = networkid.UserID(otherUserID.String) - if parentID.Valid { - p.ParentKey = networkid.PortalKey{ - ID: networkid.PortalID(parentID.String), - Receiver: networkid.UserLoginID(parentReceiver.String), - } - } + p.ParentID = networkid.PortalID(parentID.String) p.RelayLoginID = networkid.UserLoginID(relayLoginID.String) return p, nil } @@ -287,10 +206,10 @@ func (p *Portal) sqlVariables() []any { } return []any{ p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), - dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), + dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, - p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest, + p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), - dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata}, + dbutil.JSON{Data: p.Metadata}, } } diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go deleted file mode 100644 index b667399c..00000000 --- a/bridgev2/database/publicmedia.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "time" - - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/id" -) - -type PublicMediaQuery struct { - BridgeID networkid.BridgeID - *dbutil.QueryHelper[*PublicMedia] -} - -type PublicMedia struct { - BridgeID networkid.BridgeID - PublicID string - MXC id.ContentURI - Keys *attachment.EncryptedFile - MimeType string - Expiry time.Time -} - -const ( - upsertPublicMediaQuery = ` - INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry - ` - getPublicMediaQuery = ` - SELECT bridge_id, public_id, mxc, keys, mimetype, expiry - FROM public_media WHERE bridge_id=$1 AND public_id=$2 - ` -) - -func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error { - ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID) - return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...) -} - -func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) { - return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID) -} - -func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) { - var expiry sql.NullInt64 - var mimetype sql.NullString - err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry) - if err != nil { - return nil, err - } - if expiry.Valid { - pm.Expiry = time.Unix(0, expiry.Int64) - } - pm.MimeType = mimetype.String - return pm, nil -} - -func (pm *PublicMedia) sqlVariables() []any { - return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)} -} diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index b65a5c38..08ab2c8e 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -41,11 +41,11 @@ const ( getReactionBaseQuery = ` SELECT bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction ` - getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6` - getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 AND emoji_id=$5 ORDER BY message_part_id ASC LIMIT 1` - getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 ORDER BY timestamp DESC` - getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3` - getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4` + getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` + getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1` + getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 ORDER BY timestamp DESC` + getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` + getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3` getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` upsertReactionQuery = ` INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata) @@ -54,28 +54,28 @@ const ( DO UPDATE SET sender_mxid=excluded.sender_mxid, mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata ` deleteReactionQuery = ` - DELETE FROM reaction WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6 + DELETE FROM reaction WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5 ` ) -func (rq *ReactionQuery) GetByID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { - return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, receiver, messageID, messagePartID, senderID, emojiID) +func (rq *ReactionQuery) GetByID(ctx context.Context, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, messageID, messagePartID, senderID, emojiID) } -func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { - return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, receiver, messageID, senderID, emojiID) +func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, messageID, senderID, emojiID) } -func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) { - return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, receiver, messageID, senderID) +func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, messageID, senderID) } -func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) ([]*Reaction, error) { - return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, receiver, messageID) +func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, messageID networkid.MessageID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, messageID) } -func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) { - return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, receiver, messageID, partID) +func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, messageID, partID) } func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { @@ -89,7 +89,7 @@ func (rq *ReactionQuery) Upsert(ctx context.Context, reaction *Reaction) error { func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error { ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID) - return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.Room.Receiver, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID) + return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID) } func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 6092dc24..aeb9522e 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v27 (compatible with v9+): Latest revision +-- v0 -> v16 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -31,6 +31,8 @@ CREATE TABLE portal ( mxid TEXT, parent_id TEXT, + -- This is not accessed by the bridge, it's only used for the portal parent foreign key. + -- Parent groups are probably never DMs, so they don't need a receiver. parent_receiver TEXT NOT NULL DEFAULT '', relay_bridge_id TEXT, @@ -48,11 +50,9 @@ CREATE TABLE portal ( topic_set BOOLEAN NOT NULL, name_is_custom BOOLEAN NOT NULL DEFAULT false, in_space BOOLEAN NOT NULL, - message_request BOOLEAN NOT NULL DEFAULT false, room_type TEXT NOT NULL, disappear_type TEXT, disappear_timer BIGINT, - cap_state jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id, receiver), @@ -64,8 +64,6 @@ CREATE TABLE portal ( REFERENCES user_login (bridge_id, id) ON DELETE SET NULL ON UPDATE CASCADE ); -CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); -CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); CREATE TABLE ghost ( bridge_id TEXT NOT NULL, @@ -80,7 +78,6 @@ CREATE TABLE ghost ( contact_info_set BOOLEAN NOT NULL, is_bot BOOLEAN NOT NULL, identifiers jsonb NOT NULL, - extra_profile jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id) @@ -92,7 +89,7 @@ CREATE TABLE message ( -- would try to set bridge_id to null as well. -- only: sqlite (line commented) --- rowid INTEGER PRIMARY KEY, +-- rowid INTEGER PRIMARY KEY, -- only: postgres rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, @@ -107,11 +104,9 @@ CREATE TABLE message ( sender_mxid TEXT NOT NULL, timestamp BIGINT NOT NULL, edit_count INTEGER NOT NULL, - double_puppeted BOOLEAN, thread_root_id TEXT, reply_to_id TEXT, reply_to_part_id TEXT, - send_txn_id TEXT, metadata jsonb NOT NULL, CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) @@ -120,9 +115,7 @@ CREATE TABLE message ( CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) REFERENCES ghost (bridge_id, id) ON DELETE CASCADE ON UPDATE CASCADE, - CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id), - CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid), - CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id) + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) ); CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); @@ -130,18 +123,12 @@ CREATE TABLE disappearing_message ( bridge_id TEXT NOT NULL, mx_room TEXT NOT NULL, mxid TEXT NOT NULL, - timestamp BIGINT NOT NULL DEFAULT 0, type TEXT NOT NULL, timer BIGINT NOT NULL, disappear_at BIGINT, - PRIMARY KEY (bridge_id, mxid), - CONSTRAINT disappearing_message_portal_fkey - FOREIGN KEY (bridge_id, mx_room) - REFERENCES portal (bridge_id, mxid) - ON DELETE CASCADE + PRIMARY KEY (bridge_id, mxid) ); -CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); CREATE TABLE reaction ( bridge_id TEXT NOT NULL, @@ -167,8 +154,7 @@ CREATE TABLE reaction ( ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT reaction_sender_fkey FOREIGN KEY (bridge_id, sender_id) REFERENCES ghost (bridge_id, id) - ON DELETE CASCADE ON UPDATE CASCADE, - CONSTRAINT reaction_mxid_unique UNIQUE (bridge_id, mxid) + ON DELETE CASCADE ON UPDATE CASCADE ); CREATE INDEX reaction_room_idx ON reaction (bridge_id, room_id, room_receiver); @@ -212,22 +198,3 @@ CREATE TABLE backfill_task ( REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE ); - -CREATE TABLE kv_store ( - bridge_id TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - - PRIMARY KEY (bridge_id, key) -); - -CREATE TABLE public_media ( - bridge_id TEXT NOT NULL, - public_id TEXT NOT NULL, - mxc TEXT NOT NULL, - keys jsonb, - mimetype TEXT, - expiry BIGINT, - - PRIMARY KEY (bridge_id, public_id) -); diff --git a/bridgev2/database/upgrades/17-message-mxid-unique.sql b/bridgev2/database/upgrades/17-message-mxid-unique.sql deleted file mode 100644 index ee53b3f0..00000000 --- a/bridgev2/database/upgrades/17-message-mxid-unique.sql +++ /dev/null @@ -1,8 +0,0 @@ --- v17 (compatible with v9+): Add unique constraint for message and reaction mxids -DELETE FROM message WHERE mxid IN (SELECT mxid FROM message GROUP BY mxid HAVING COUNT(*) > 1); --- only: postgres for next 2 lines -ALTER TABLE message ADD CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid); -ALTER TABLE reaction ADD CONSTRAINT reaction_mxid_unique UNIQUE (bridge_id, mxid); --- only: sqlite for next 2 lines -CREATE UNIQUE INDEX message_mxid_unique ON message (bridge_id, mxid); -CREATE UNIQUE INDEX reaction_mxid_unique ON reaction (bridge_id, mxid); diff --git a/bridgev2/database/upgrades/18-kv-store.sql b/bridgev2/database/upgrades/18-kv-store.sql deleted file mode 100644 index 9d233095..00000000 --- a/bridgev2/database/upgrades/18-kv-store.sql +++ /dev/null @@ -1,8 +0,0 @@ --- v18 (compatible with v9+): Add generic key-value store -CREATE TABLE kv_store ( - bridge_id TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - - PRIMARY KEY (bridge_id, key) -); diff --git a/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql b/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql deleted file mode 100644 index ec6fe836..00000000 --- a/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v19 (compatible with v9+): Add double puppeted state to messages -ALTER TABLE message ADD COLUMN double_puppeted BOOLEAN; diff --git a/bridgev2/database/upgrades/20-portal-capabilities.sql b/bridgev2/database/upgrades/20-portal-capabilities.sql deleted file mode 100644 index 00bd96ca..00000000 --- a/bridgev2/database/upgrades/20-portal-capabilities.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v20 (compatible with v9+): Add portal capability state -ALTER TABLE portal ADD COLUMN cap_state jsonb; diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql deleted file mode 100644 index d1c1ad9a..00000000 --- a/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql +++ /dev/null @@ -1,8 +0,0 @@ --- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid -CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); -DELETE FROM disappearing_message WHERE mx_room NOT IN (SELECT mxid FROM portal WHERE mxid IS NOT NULL); -ALTER TABLE disappearing_message - ADD CONSTRAINT disappearing_message_portal_fkey - FOREIGN KEY (bridge_id, mx_room) - REFERENCES portal (bridge_id, mxid) - ON DELETE CASCADE; diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql deleted file mode 100644 index f5468c6b..00000000 --- a/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql +++ /dev/null @@ -1,24 +0,0 @@ --- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid -CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); -CREATE TABLE disappearing_message_new ( - bridge_id TEXT NOT NULL, - mx_room TEXT NOT NULL, - mxid TEXT NOT NULL, - type TEXT NOT NULL, - timer BIGINT NOT NULL, - disappear_at BIGINT, - - PRIMARY KEY (bridge_id, mxid), - CONSTRAINT disappearing_message_portal_fkey - FOREIGN KEY (bridge_id, mx_room) - REFERENCES portal (bridge_id, mxid) - ON DELETE CASCADE -); - -WITH portal_mxids AS (SELECT mxid FROM portal WHERE mxid IS NOT NULL) -INSERT INTO disappearing_message_new (bridge_id, mx_room, mxid, type, timer, disappear_at) -SELECT bridge_id, mx_room, mxid, type, timer, disappear_at -FROM disappearing_message WHERE mx_room IN portal_mxids; - -DROP TABLE disappearing_message; -ALTER TABLE disappearing_message_new RENAME TO disappearing_message; diff --git a/bridgev2/database/upgrades/22-message-send-txn-id.sql b/bridgev2/database/upgrades/22-message-send-txn-id.sql deleted file mode 100644 index 8933984e..00000000 --- a/bridgev2/database/upgrades/22-message-send-txn-id.sql +++ /dev/null @@ -1,6 +0,0 @@ --- v22 (compatible with v9+): Add message send transaction ID column -ALTER TABLE message ADD COLUMN send_txn_id TEXT; --- only: postgres -ALTER TABLE message ADD CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id); --- only: sqlite -CREATE UNIQUE INDEX message_txn_id_unique ON message (bridge_id, room_receiver, send_txn_id); diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql deleted file mode 100644 index ecd00b8d..00000000 --- a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v23 (compatible with v9+): Add event timestamp for disappearing messages -ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0; diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql deleted file mode 100644 index c4290090..00000000 --- a/bridgev2/database/upgrades/24-public-media.sql +++ /dev/null @@ -1,11 +0,0 @@ --- v24 (compatible with v9+): Custom URLs for public media -CREATE TABLE public_media ( - bridge_id TEXT NOT NULL, - public_id TEXT NOT NULL, - mxc TEXT NOT NULL, - keys jsonb, - mimetype TEXT, - expiry BIGINT, - - PRIMARY KEY (bridge_id, public_id) -); diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql deleted file mode 100644 index b9d82a7a..00000000 --- a/bridgev2/database/upgrades/25-message-requests.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v25 (compatible with v9+): Flag for message request portals -ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql deleted file mode 100644 index ae5d8cad..00000000 --- a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql +++ /dev/null @@ -1,3 +0,0 @@ --- v26 (compatible with v9+): Add room index for disappearing message table and portal parents -CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); -CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql deleted file mode 100644 index e8e0549a..00000000 --- a/bridgev2/database/upgrades/27-ghost-extra-profile.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v27 (compatible with v9+): Add column for extra ghost profile metadata -ALTER TABLE ghost ADD COLUMN extra_profile jsonb; diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 00ff01c9..610e7d60 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -12,8 +12,8 @@ import ( "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/id" ) @@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { func (u *UserLogin) sqlVariables() []any { var remoteProfile dbutil.JSON - if !u.RemoteProfile.IsZero() { + if !u.RemoteProfile.IsEmpty() { remoteProfile.Data = &u.RemoteProfile } return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index e928a4c7..278b236b 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -67,9 +67,6 @@ const ( markLoginAsPreferredQuery = ` UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 ` - markAllNotInSpaceQuery = ` - UPDATE user_portal SET in_space=false WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3 - ` deleteUserPortalQuery = ` DELETE FROM user_portal WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 ` @@ -113,10 +110,6 @@ func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogi return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } -func (upq *UserPortalQuery) MarkAllNotInSpace(ctx context.Context, portal networkid.PortalKey) error { - return upq.Exec(ctx, markAllNotInSpaceQuery, upq.BridgeID, portal.ID, portal.Receiver) -} - func (upq *UserPortalQuery) Delete(ctx context.Context, up *UserPortal) error { return upq.Exec(ctx, deleteUserPortalQuery, up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver) } diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index b5c37e8f..5f9900a5 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -8,7 +8,6 @@ package bridgev2 import ( "context" - "sync/atomic" "time" "github.com/rs/zerolog" @@ -21,44 +20,27 @@ import ( type DisappearLoop struct { br *Bridge - nextCheck atomic.Pointer[time.Time] - stop atomic.Pointer[context.CancelFunc] + NextCheck time.Time + stop context.CancelFunc } const DisappearCheckInterval = 1 * time.Hour func (dl *DisappearLoop) Start() { log := dl.br.Log.With().Str("component", "disappear loop").Logger() - ctx, stop := context.WithCancel(log.WithContext(context.Background())) - if oldStop := dl.stop.Swap(&stop); oldStop != nil { - (*oldStop)() - } + ctx := log.WithContext(context.Background()) + ctx, dl.stop = context.WithCancel(ctx) log.Debug().Msg("Disappearing message loop starting") for { - nextCheck := time.Now().Add(DisappearCheckInterval) - dl.nextCheck.Store(&nextCheck) - const MessageLimit = 200 - messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit) + dl.NextCheck = time.Now().Add(DisappearCheckInterval) + messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval) if err != nil { log.Err(err).Msg("Failed to get upcoming disappearing messages") } else if len(messages) > 0 { - if len(messages) >= MessageLimit { - lastDisappearTime := messages[len(messages)-1].DisappearAt - log.Debug(). - Int("message_count", len(messages)). - Time("last_due", lastDisappearTime). - Msg("Deleting disappearing messages synchronously and checking again immediately") - // Store the expected next check time to avoid Add spawning unnecessary goroutines. - // This can be in the past, in which case Add will put everything in the db, which is also fine. - dl.nextCheck.Store(&lastDisappearTime) - // If there are many messages, process them synchronously and then check again. - dl.sleepAndDisappear(ctx, messages...) - continue - } go dl.sleepAndDisappear(ctx, messages...) } select { - case <-time.After(time.Until(dl.GetNextCheck())): + case <-time.After(time.Until(dl.NextCheck)): case <-ctx.Done(): log.Debug().Msg("Disappearing message loop stopping") return @@ -66,34 +48,20 @@ func (dl *DisappearLoop) Start() { } } -func (dl *DisappearLoop) GetNextCheck() time.Time { - if dl == nil { - return time.Time{} - } - nextCheck := dl.nextCheck.Load() - if nextCheck == nil { - return time.Time{} - } - return *nextCheck -} - func (dl *DisappearLoop) Stop() { - if dl == nil { - return - } - if stop := dl.stop.Load(); stop != nil { - (*stop)() + if dl.stop != nil { + dl.stop() } } -func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) { - startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS) +func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") return } startedMessages = slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { - return dm.DisappearAt.After(dl.GetNextCheck()) + return dm.DisappearAt.After(dl.NextCheck) }) slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int { return a.DisappearAt.Compare(b.DisappearAt) @@ -110,25 +78,14 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa Stringer("event_id", dm.EventID). Msg("Failed to save disappearing message") } - if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.GetNextCheck()) { - go dl.sleepAndDisappear(zerolog.Ctx(ctx).WithContext(dl.br.BackgroundCtx), dm) + if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) { + go dl.sleepAndDisappear(context.WithoutCancel(ctx), dm) } } func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { for _, msg := range dms { - timeUntilDisappear := time.Until(msg.DisappearAt) - if timeUntilDisappear <= 0 { - if ctx.Err() != nil { - return - } - } else { - select { - case <-time.After(timeUntilDisappear): - case <-ctx.Done(): - return - } - } + time.Sleep(time.Until(msg.DisappearAt)) resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: msg.EventID, diff --git a/bridgev2/errors.go b/bridgev2/errors.go index f6677d2e..2834b298 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -9,10 +9,8 @@ package bridgev2 import ( "errors" "fmt" - "net/http" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" ) // ErrIgnoringRemoteEvent can be returned by [RemoteMessage.ConvertMessage] or [RemoteEdit.ConvertEdit] @@ -23,9 +21,8 @@ var ErrIgnoringRemoteEvent = errors.New("ignoring remote event") // and a status should not be sent yet. The message will still be saved into the database. var ErrNoStatus = errors.New("omit message status") -// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier or CreateChatWithGhost to signal that -// the identifier is valid, but can't be reached by the current login, and the caller should try the next -// login if there are more. +// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier to signal that the identifier is valid, +// but can't be reached by the current login, and the caller should try the next login if there are more. // // This should generally only be returned when resolving internal IDs (which happens when initiating chats via Matrix). // For example, Google Messages would return this when trying to resolve another login's user ID, @@ -38,58 +35,29 @@ var ErrNotLoggedIn = errors.New("not logged in") // but direct media is not enabled. var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") -var ErrPortalIsDeleted = errors.New("portal is deleted") -var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event") - // Common message status errors var ( - ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() - ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) - ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) - ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) - ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) - ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) - ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) - - ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) - ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) - ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) - - ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) -) - -// Common login interface errors -var ( - ErrInvalidLoginFlowID error = RespError(mautrix.MNotFound.WithMessage("Invalid login flow ID")) + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ) // RespError is a class of error that certain network interface methods can return to ensure that the error @@ -111,14 +79,6 @@ func (re RespError) Is(err error) bool { return errors.Is(err, mautrix.RespError(re)) } -func (re RespError) Write(w http.ResponseWriter) { - mautrix.RespError(re).Write(w) -} - -func (re RespError) WithMessage(msg string, args ...any) RespError { - return RespError(mautrix.RespError(re).WithMessage(msg, args...)) -} - func (re RespError) AppendMessage(append string, args ...any) RespError { re.Err += fmt.Sprintf(append, args...) return re diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 590dd1dc..e4e007cd 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -9,15 +9,12 @@ package bridgev2 import ( "context" "crypto/sha256" - "encoding/json" "fmt" - "maps" "net/http" - "slices" "github.com/rs/zerolog" - "go.mau.fi/util/exerrors" "go.mau.fi/util/exmime" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -88,13 +85,7 @@ func (br *Bridge) GetGhostByMXID(ctx context.Context, mxid id.UserID) (*Ghost, e func (br *Bridge) GetGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - ghost, err := br.unlockedGetGhostByID(ctx, id, false) - if err != nil { - return nil, err - } else if ghost == nil { - panic(fmt.Errorf("unlockedGetGhostByID(ctx, %q, false) returned nil", id)) - } - return ghost, nil + return br.unlockedGetGhostByID(ctx, id, false) } func (br *Bridge) GetExistingGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { @@ -137,11 +128,10 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32 } type UserInfo struct { - Identifiers []string - Name *string - Avatar *Avatar - IsBot *bool - ExtraProfile database.ExtraProfile + Identifiers []string + Name *string + Avatar *Avatar + IsBot *bool ExtraUpdates ExtraUpdater[*Ghost] } @@ -162,7 +152,7 @@ func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool { } func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { - if ghost.AvatarID == avatar.ID && (avatar.Remove || ghost.AvatarMXC != "") && ghost.AvatarSet { + if ghost.AvatarID == avatar.ID && ghost.AvatarSet { return false } ghost.AvatarID = avatar.ID @@ -172,7 +162,7 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { ghost.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar") return true - } else if newHash == ghost.AvatarHash && ghost.AvatarMXC != "" && ghost.AvatarSet { + } else if newHash == ghost.AvatarHash && ghost.AvatarSet { return true } ghost.AvatarHash = newHash @@ -189,9 +179,23 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { return true } -func (ghost *Ghost) getExtraProfileMeta() any { +func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool { + if identifiers != nil { + slices.Sort(identifiers) + } + if ghost.ContactInfoSet && + (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) && + (isBot == nil || *isBot == ghost.IsBot) { + return false + } + if identifiers != nil { + ghost.Identifiers = identifiers + } + if isBot != nil { + ghost.IsBot = *isBot + } bridgeName := ghost.Bridge.Network.GetName() - baseExtra := &event.BeeperProfileExtra{ + meta := &event.BeeperProfileExtra{ RemoteID: string(ghost.ID), Identifiers: ghost.Identifiers, Service: bridgeName.BeeperBridgeType, @@ -199,36 +203,7 @@ func (ghost *Ghost) getExtraProfileMeta() any { IsBridgeBot: false, IsNetworkBot: ghost.IsBot, } - if len(ghost.ExtraProfile) == 0 { - return baseExtra - } - mergedExtra := maps.Clone(ghost.ExtraProfile) - baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra)) - exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra)) - return mergedExtra -} - -func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool { - if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta { - ghost.ContactInfoSet = false - return false - } - if identifiers != nil { - slices.Sort(identifiers) - } - changed := extraProfile.CopyTo(&ghost.ExtraProfile) - if identifiers != nil { - changed = changed || !slices.Equal(identifiers, ghost.Identifiers) - ghost.Identifiers = identifiers - } - if isBot != nil { - changed = changed || *isBot != ghost.IsBot - ghost.IsBot = *isBot - } - if ghost.ContactInfoSet && !changed { - return false - } - err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta()) + err := ghost.Intent.SetExtraProfileMeta(ctx, meta) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") } else { @@ -250,7 +225,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool { } func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) { - if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { + if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } info, err := source.Client.GetUserInfo(ctx, ghost) @@ -260,16 +235,12 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin zerolog.Ctx(ctx).Debug(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). - Bool("has_avatar", ghost.AvatarMXC != ""). - Bool("avatar_set", ghost.AvatarSet). Msg("Updating ghost info in IfNecessary call") ghost.UpdateInfo(ctx, info) } else { zerolog.Ctx(ctx).Trace(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). - Bool("has_avatar", ghost.AvatarMXC != ""). - Bool("avatar_set", ghost.AvatarSet). Msg("No ghost info received in IfNecessary call") } } @@ -297,14 +268,9 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { } if info.Avatar != nil { update = ghost.UpdateAvatar(ctx, info.Avatar) || update - } else if oldAvatar == "" && !ghost.AvatarSet { - // Special case: nil avatar means we're not expecting one ever, if we don't currently have - // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary. - ghost.AvatarSet = true - update = true } - if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil { - update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update + if info.Identifiers != nil || info.IsBot != nil { + update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update } if info.ExtraUpdates != nil { update = info.ExtraUpdates(ctx, ghost) || update diff --git a/bridgev2/login.go b/bridgev2/login.go index b8321719..7acccd9a 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -13,7 +13,6 @@ import ( "strings" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" ) // LoginProcess represents a single occurrence of a user logging into the remote network. @@ -33,17 +32,6 @@ type LoginProcess interface { Cancel() } -type LoginProcessWithOverride interface { - LoginProcess - // StartWithOverride starts the process with the intent of re-authenticating an existing login. - // - // The call to this is mutually exclusive with the call to the default Start method. - // - // The user login being overridden will still be logged out automatically - // in case the complete step returns a different login. - StartWithOverride(ctx context.Context, override *UserLogin) (*LoginStep, error) -} - type LoginProcessDisplayAndWait interface { LoginProcess Wait(ctx context.Context) (*LoginStep, error) @@ -160,12 +148,6 @@ type LoginCookiesParams struct { // The snippet will evaluate to a promise that resolves when the relevant fields are found. // Fields that are not present in the promise result must be extracted another way. ExtractJS string `json:"extract_js,omitempty"` - // A regex pattern that the URL should match before the client closes the webview. - // - // The client may submit the login if the user closes the webview after all cookies are collected - // even if this URL is not reached, but it should only automatically close the webview after - // both cookies and the URL match. - WaitForURLPattern string `json:"wait_for_url_pattern,omitempty"` } type LoginInputFieldType string @@ -177,10 +159,6 @@ const ( LoginInputFieldTypeEmail LoginInputFieldType = "email" LoginInputFieldType2FACode LoginInputFieldType = "2fa_code" LoginInputFieldTypeToken LoginInputFieldType = "token" - LoginInputFieldTypeURL LoginInputFieldType = "url" - LoginInputFieldTypeDomain LoginInputFieldType = "domain" - LoginInputFieldTypeSelect LoginInputFieldType = "select" - LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code" ) type LoginInputDataField struct { @@ -192,13 +170,8 @@ type LoginInputDataField struct { Name string `json:"name"` // The description of the field shown to the user. Description string `json:"description"` - // A default value that the client can pre-fill the field with. - DefaultValue string `json:"default_value,omitempty"` // A regex pattern that the client can use to validate input client-side. Pattern string `json:"pattern,omitempty"` - // For fields of type select, the valid options. - // Pattern may also be filled with a regex that matches the same options. - Options []string `json:"options,omitempty"` // A function that validates the input and optionally cleans it up before it's submitted to the connector. Validate func(string) (string, error) `json:"-"` } @@ -273,23 +246,6 @@ func (f *LoginInputDataField) FillDefaultValidate() { type LoginUserInputParams struct { // The fields that the user needs to fill in. Fields []LoginInputDataField `json:"fields"` - - // Attachments to display alongside the input fields. - Attachments []*LoginUserInputAttachment `json:"attachments"` -} - -type LoginUserInputAttachment struct { - Type event.MessageType `json:"type,omitempty"` - FileName string `json:"filename,omitempty"` - Content []byte `json:"content,omitempty"` - Info LoginUserInputAttachmentInfo `json:"info,omitempty"` -} - -type LoginUserInputAttachmentInfo struct { - MimeType string `json:"mimetype,omitempty"` - Width int `json:"w,omitempty"` - Height int `json:"h,omitempty"` - Size int `json:"size,omitempty"` } type LoginCompleteParams struct { diff --git a/bridgev2/matrix/analytics.go b/bridgev2/matrix/analytics.go deleted file mode 100644 index 7eb2a33a..00000000 --- a/bridgev2/matrix/analytics.go +++ /dev/null @@ -1,62 +0,0 @@ -package matrix - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - - "maunium.net/go/mautrix/id" -) - -func (br *Connector) trackSync(userID id.UserID, event string, properties map[string]any) error { - var buf bytes.Buffer - var analyticsUserID string - if br.Config.Analytics.UserID != "" { - analyticsUserID = br.Config.Analytics.UserID - } else { - analyticsUserID = userID.String() - } - err := json.NewEncoder(&buf).Encode(map[string]any{ - "userId": analyticsUserID, - "event": event, - "properties": properties, - }) - if err != nil { - return err - } - - req, err := http.NewRequest(http.MethodPost, br.Config.Analytics.URL, &buf) - if err != nil { - return err - } - req.SetBasicAuth(br.Config.Analytics.Token, "") - resp, err := br.AS.HTTPClient.Do(req) - if err != nil { - return err - } - _ = resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code %d", resp.StatusCode) - } - return nil -} - -func (br *Connector) TrackAnalytics(userID id.UserID, event string, props map[string]any) { - if br.Config.Analytics.Token == "" || br.Config.Analytics.URL == "" { - return - } - - if props == nil { - props = map[string]any{} - } - props["bridge"] = br.Bridge.Network.GetName().BeeperBridgeType - go func() { - err := br.trackSync(userID, event, props) - if err != nil { - br.Log.Err(err).Str("component", "analytics").Str("event", event).Msg("Error tracking event") - } else { - br.Log.Debug().Str("component", "analytics").Str("event", event).Msg("Tracked event") - } - }() -} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 5a2df953..c2b3f7cd 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -10,34 +10,33 @@ import ( "context" "crypto/sha256" "encoding/base64" + "encoding/json" "errors" "fmt" - "net/http" "net/url" "os" "regexp" "strings" "sync" "time" + "unsafe" + "github.com/gorilla/mux" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" - "go.mau.fi/util/exbytes" "go.mau.fi/util/exsync" - "go.mau.fi/util/ptr" "go.mau.fi/util/random" - "golang.org/x/sync/semaphore" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/mediaproxy" @@ -71,7 +70,6 @@ type Connector struct { DoublePuppet *doublePuppetUtil MediaProxy *mediaproxy.MediaProxy - uploadSema *semaphore.Weighted dmaSigKey [32]byte pubMediaSigKey []byte @@ -81,8 +79,6 @@ type Connector struct { MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions - SpecCaps *mautrix.RespCapabilities - specCapsLock sync.Mutex Capabilities *bridgev2.MatrixCapabilities IgnoreUnsupportedServer bool @@ -104,12 +100,7 @@ type Connector struct { var ( _ bridgev2.MatrixConnector = (*Connector)(nil) _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) - _ bridgev2.MatrixConnectorWithArbitraryRoomState = (*Connector)(nil) _ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil) - _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) - _ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil) - _ bridgev2.MatrixConnectorWithURLPreviews = (*Connector)(nil) - _ bridgev2.MatrixConnectorWithAnalytics = (*Connector)(nil) ) func NewConnector(cfg *bridgeconfig.Config) *Connector { @@ -117,7 +108,6 @@ func NewConnector(cfg *bridgeconfig.Config) *Connector { c.Config = cfg c.userIDRegex = cfg.MakeUserIDRegex("(.+)") c.MediaConfig.UploadSize = 50 * 1024 * 1024 - c.uploadSema = semaphore.NewWeighted(c.MediaConfig.UploadSize + 1) c.Capabilities = &bridgev2.MatrixCapabilities{} c.doublePuppetIntents = exsync.NewMap[id.UserID, *appservice.IntentAPI]() return c @@ -139,25 +129,16 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { } br.EventProcessor.On(event.EventMessage, br.handleRoomEvent) br.EventProcessor.On(event.EventSticker, br.handleRoomEvent) - br.EventProcessor.On(event.EventUnstablePollStart, br.handleRoomEvent) - br.EventProcessor.On(event.EventUnstablePollResponse, br.handleRoomEvent) br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) - br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) - br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) - br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) - br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) - br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent) - br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) - br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.(*commands.Processor).AddHandlers( @@ -179,17 +160,6 @@ func (br *Connector) Start(ctx context.Context) error { if err != nil { return err } - needsStateResync := br.Config.Encryption.Default && - br.Bridge.DB.KV.Get(ctx, database.KeyEncryptionStateResynced) != "true" - if needsStateResync { - dbExists, err := br.StateStore.TableExists(ctx, "mx_version") - if err != nil { - return fmt.Errorf("failed to check if mx_version table exists: %w", err) - } else if !dbExists { - needsStateResync = false - br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") - } - } err = br.StateStore.Upgrade(ctx) if err != nil { return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} @@ -226,66 +196,24 @@ func (br *Connector) Start(ctx context.Context) error { } parsed, _ := url.Parse(br.Bridge.Network.GetName().NetworkURL) if parsed != nil { - br.deterministicEventIDServer = strings.TrimPrefix(parsed.Hostname(), "www.") + br.deterministicEventIDServer = parsed.Hostname() } br.AS.Ready = true if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { br.wsStopPinger = make(chan struct{}, 1) go br.websocketServerPinger() } - if needsStateResync { - br.ResyncEncryptionState(ctx) - } return nil } -func (br *Connector) ResyncEncryptionState(ctx context.Context) { - log := zerolog.Ctx(ctx) - roomIDScanner := dbutil.ConvertRowFn[id.RoomID](dbutil.ScanSingleColumn[id.RoomID]) - rooms, err := roomIDScanner.NewRowIter(br.Bridge.DB.Query(ctx, ` - SELECT rooms.room_id - FROM (SELECT DISTINCT(room_id) FROM mx_user_profile WHERE room_id<>'') rooms - LEFT JOIN mx_room_state ON rooms.room_id = mx_room_state.room_id - WHERE mx_room_state.encryption IS NULL - `)).AsList() - if err != nil { - log.Err(err).Msg("Failed to get room list to resync state") - return - } - var failedCount, successCount, forbiddenCount int - for _, roomID := range rooms { - if roomID == "" { - continue - } - var outContent *event.EncryptionEventContent - err = br.Bot.Client.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent) - if errors.Is(err, mautrix.MForbidden) { - // Most likely non-existent room - log.Debug().Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") - forbiddenCount++ - } else if err != nil { - log.Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") - failedCount++ - } else { - successCount++ - } - } - br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") - log.Info(). - Int("success_count", successCount). - Int("forbidden_count", forbiddenCount). - Int("failed_count", failedCount). - Msg("Resynced rooms") -} - func (br *Connector) GetPublicAddress() string { if br.Config.AppService.PublicAddress == "https://bridge.example.com" { return "" } - return strings.TrimRight(br.Config.AppService.PublicAddress, "/") + return br.Config.AppService.PublicAddress } -func (br *Connector) GetRouter() *http.ServeMux { +func (br *Connector) GetRouter() *mux.Router { if br.GetPublicAddress() != "" { return br.AS.Router } @@ -296,37 +224,13 @@ func (br *Connector) GetCapabilities() *bridgev2.MatrixCapabilities { return br.Capabilities } -func sendStopSignal(ch chan struct{}) { - if ch != nil { - select { - case ch <- struct{}{}: - default: - } - } -} - -func (br *Connector) PreStop() { +func (br *Connector) Stop() { br.stopping = true br.AS.Stop() - if stopWebsocket := br.AS.StopWebsocket; stopWebsocket != nil { - stopWebsocket(appservice.ErrWebsocketManualStop) - } - sendStopSignal(br.wsStopPinger) - sendStopSignal(br.wsShortCircuitReconnectBackoff) -} - -func (br *Connector) Stop() { br.EventProcessor.Stop() if br.Crypto != nil { br.Crypto.Stop() } - if wsStopChan := br.wsStopped; wsStopChan != nil { - select { - case <-wsStopChan: - case <-time.After(4 * time.Second): - br.Log.Warn().Msg("Timed out waiting for websocket to close") - } - } } var MinSpecVersion = mautrix.SpecV14 @@ -344,21 +248,16 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) { } func (br *Connector) ensureConnection(ctx context.Context) { - triedToRegister := false for { versions, err := br.Bot.Versions(ctx) if err != nil { - if errors.Is(err, mautrix.MForbidden) && !triedToRegister { + if errors.Is(err, mautrix.MForbidden) { br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") err = br.Bot.EnsureRegistered(ctx) if err != nil { br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") os.Exit(16) } - triedToRegister = true - } else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) { - br.logInitialRequestError(err, "/versions request failed with auth error") - os.Exit(16) } else { br.Log.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") time.Sleep(10 * time.Second) @@ -368,9 +267,6 @@ func (br *Connector) ensureConnection(ctx context.Context) { *br.AS.SpecVersions = *versions br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) - br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange) - br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) || - (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo) break } } @@ -411,23 +307,50 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Log.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") return } - - br.Bot.EnsureAppserviceConnection(ctx) -} - -func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities { - br.specCapsLock.Lock() - defer br.specCapsLock.Unlock() - if br.SpecCaps != nil { - return br.SpecCaps + var pingResp *mautrix.RespAppservicePing + var txnID string + var retryCount int + const maxRetries = 6 + for { + txnID = br.Bot.TxnID() + pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) + if err == nil { + break + } + var httpErr mautrix.HTTPError + var pingErrBody string + if errors.As(err, &httpErr) && httpErr.RespError != nil { + if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { + pingErrBody = strings.TrimSpace(val) + } + } + outOfRetries := retryCount >= maxRetries + level := zerolog.ErrorLevel + if outOfRetries { + level = zerolog.FatalLevel + } + evt := br.Log.WithLevel(level).Err(err).Str("txn_id", txnID) + if pingErrBody != "" { + bodyBytes := []byte(pingErrBody) + if json.Valid(bodyBytes) { + evt.RawJSON("body", bodyBytes) + } else { + evt.Str("body", pingErrBody) + } + } + if outOfRetries { + evt.Msg("Homeserver -> bridge connection is not working") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") + os.Exit(13) + } + evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") + time.Sleep(5 * time.Second) + retryCount++ } - caps, err := br.Bot.Capabilities(ctx) - if err != nil { - br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver") - return nil - } - br.SpecCaps = caps - return caps + br.Log.Debug(). + Str("txn_id", txnID). + Int64("duration_ms", pingResp.DurationMS). + Msg("Homeserver -> bridge connection works") } func (br *Connector) fetchMediaConfig(ctx context.Context) { @@ -443,7 +366,6 @@ func (br *Connector) fetchMediaConfig(ctx context.Context) { if ok { mfsn.SetMaxFileSize(br.MediaConfig.UploadSize) } - br.uploadSema = semaphore.NewWeighted(br.MediaConfig.UploadSize + 1) } } @@ -495,15 +417,11 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI { func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error { if br.Websocket { br.hasSentAnyStates = true - return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ + return br.AS.SendWebsocket(&appservice.WebsocketRequest{ Command: "bridge_status", Data: state, }) } else if br.Config.Homeserver.StatusEndpoint != "" { - // Connecting states aren't really relevant unless the bridge runs somewhere with an unreliable network - if state.StateEvent == status.StateConnecting { - return nil - } return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) } else { return nil @@ -515,31 +433,26 @@ func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.Message } func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo, editEvent id.EventID) id.EventID { - if evt.EventType.IsEphemeral() || evt.SourceEventID == "" { + if evt.EventType.IsEphemeral() || evt.EventID == "" { return "" } log := zerolog.Ctx(ctx) - - if !evt.IsSourceEventDoublePuppeted { - err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) - if err != nil { - log.Err(err).Msg("Failed to send message checkpoint") - } + err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) + if err != nil { + log.Err(err).Msg("Failed to send message checkpoint") } - if !ms.DisableMSS && br.Config.Matrix.MessageStatusEvents { mssEvt := ms.ToMSSEvent(evt) - _, err := br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) + _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). - Stringer("event_id", evt.SourceEventID). + Stringer("event_id", evt.EventID). Any("mss_content", mssEvt). Msg("Failed to send MSS event") } } - if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice && - (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { content := ms.ToNoticeEvent(evt) if editEvent != "" { content.SetEdit(editEvent) @@ -548,7 +461,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). - Stringer("event_id", evt.SourceEventID). + Stringer("event_id", evt.EventID). Str("notice_message", content.Body). Msg("Failed to send notice event") } else { @@ -556,22 +469,22 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 } } if ms.Status == event.MessageStatusSuccess && br.Config.Matrix.DeliveryReceipts { - err := br.Bot.SendReceipt(ctx, evt.RoomID, evt.SourceEventID, event.ReceiptTypeRead, nil) + err = br.Bot.SendReceipt(ctx, evt.RoomID, evt.EventID, event.ReceiptTypeRead, nil) if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). - Stringer("event_id", evt.SourceEventID). + Stringer("event_id", evt.EventID). Msg("Failed to send Matrix delivery receipt") } } return "" } -func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*status.MessageCheckpoint) error { +func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} if br.Websocket { - return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ + return br.AS.SendWebsocket(&appservice.WebsocketRequest{ Command: "message_checkpoint", Data: checkpointsJSON, }) @@ -582,7 +495,7 @@ func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []* return nil } - return checkpointsJSON.SendHTTP(ctx, br.AS.HTTPClient, endpoint, br.AS.Registration.AppToken) + return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) } func (br *Connector) ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) { @@ -622,38 +535,8 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve return br.Bot.PowerLevels(ctx, roomID) } -func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) { - if stateKey == "" { - switch eventType { - case event.StateCreate: - createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) - if err != nil || createEvt != nil { - return createEvt, err - } - case event.StateJoinRules: - joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID) - if err != nil { - return nil, err - } else if joinRulesContent != nil { - return &event.Event{ - Type: event.StateJoinRules, - RoomID: roomID, - StateKey: ptr.Ptr(""), - Content: event.Content{Parsed: joinRulesContent}, - }, nil - } - } - } - return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey) -} - func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID) - if err != nil { - return nil, err - } else if fetched { - return br.Bot.StateStore.GetAllMembers(ctx, roomID) - } + // TODO use cache? members, err := br.Bot.Members(ctx, roomID) if err != nil { return nil, err @@ -674,10 +557,6 @@ func (br *Connector) IsConfusableName(ctx context.Context, roomID id.RoomID, use return br.AS.StateStore.IsConfusableName(ctx, roomID, userID, name) } -func (br *Connector) GetUniqueBridgeID() string { - return fmt.Sprintf("%s/%s", br.Config.Homeserver.Domain, br.Config.AppService.ID) -} - func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) { if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) @@ -687,7 +566,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr if intent != nil { intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) } - if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction { + if evt.Type != event.EventEncrypted { err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) if err != nil { return nil, err @@ -703,11 +582,9 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr } func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID { - data := make([]byte, 0, len(roomID)+1+len(messageID)+1+len(partID)) + data := make([]byte, 0, len(roomID)+len(messageID)+len(partID)) data = append(data, roomID...) - data = append(data, 0) data = append(data, messageID...) - data = append(data, 0) data = append(data, partID...) hash := sha256.Sum256(data) @@ -719,11 +596,7 @@ func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid. eventID[1+hashB64Len] = ':' copy(eventID[1+hashB64Len+1:], br.deterministicEventIDServer) - return id.EventID(exbytes.UnsafeString(eventID)) -} - -func (br *Connector) GenerateDeterministicRoomID(key networkid.PortalKey) id.RoomID { - return id.RoomID(fmt.Sprintf("!%s.%s:%s", key.ID, key.Receiver, br.ServerName())) + return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID))) } func (br *Connector) GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID { @@ -752,7 +625,3 @@ func (br *Connector) HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomI } return nil } - -func (br *Connector) GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) { - return br.Bot.GetURLPreview(ctx, url) -} diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 7f18f1f5..427b369d 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -14,7 +14,6 @@ import ( "fmt" "os" "runtime/debug" - "strings" "sync" "time" @@ -24,7 +23,6 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -38,9 +36,9 @@ func init() { var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) -var NoSessionFound = crypto.ErrNoSessionFound -var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex -var UnknownMessageIndex = olm.ErrUnknownMessageIndex +var NoSessionFound = crypto.NoSessionFound +var DuplicateMessageIndex = crypto.DuplicateMessageIndex +var UnknownMessageIndex = olm.UnknownMessageIndex type CryptoHelper struct { bridge *Connector @@ -79,7 +77,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { dbutil.ZeroLogger(helper.bridge.Log.With().Str("db_section", "crypto").Logger()), string(helper.bridge.Bridge.ID), helper.bridge.AS.BotMXID(), - fmt.Sprintf("@%s:%s", strings.ReplaceAll(helper.bridge.Config.AppService.FormatUsername("%"), "_", `\_`), helper.bridge.AS.HomeserverDomain), + fmt.Sprintf("@%s:%s", helper.bridge.Config.AppService.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), helper.bridge.Config.Encryption.PickleKey, ) @@ -98,7 +96,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { Str("device_id", helper.client.DeviceID.String()). Msg("Logged in as bridge bot") helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, helper.bridge.StateStore) - helper.mach.DisableSharedGroupSessionTracking = true helper.mach.AllowKeyShare = helper.allowKeyShare encryptionConfig := helper.bridge.Config.Encryption @@ -136,19 +133,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } if isExistingDevice { - if !helper.verifyKeysAreOnServer(ctx) { - return nil - } - } else { - err = helper.ShareKeys(ctx) - if err != nil { - return fmt.Errorf("failed to share device keys: %w", err) - } - } - if helper.bridge.Config.Encryption.SelfSign { - if !helper.doSelfSign(ctx) { - os.Exit(34) - } + helper.verifyKeysAreOnServer(ctx) } go helper.resyncEncryptionInfo(context.TODO()) @@ -156,66 +141,30 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return nil } -func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool { - log := zerolog.Ctx(ctx) - hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx) - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status") - return false - } - log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status") - keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey) - if !hasKeys || keyInDB == "overwrite" { - if keyInDB != "" && keyInDB != "overwrite" { - log.WithLevel(zerolog.FatalLevel). - Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.") - return false - } - recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx) - if recoveryKey != "" { - helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey) - } - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign") - return false - } - log.Info().Msg("Generated new recovery key and self-signed bot device") - } else if !isVerified { - if keyInDB == "" { - log.WithLevel(zerolog.FatalLevel). - Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.") - return false - } - err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB) - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key") - return false - } - log.Info().Msg("Verified bot device with existing recovery key") - } - return true -} - func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) - roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return } + roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + if err != nil { + log.Err(err).Msg("Failed to scan rooms for resync") + return + } if len(roomIDs) > 0 { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { var evt event.EncryptionEventContent err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { - log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event") + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") _, err = helper.store.DB.Exec(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { - log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync") + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") } } else { maxAge := evt.RotationPeriodMillis @@ -238,9 +187,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL `, maxAge, maxMessages, roomID) if err != nil { - log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table") + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") } else { - log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table") + log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") } } } @@ -253,7 +202,7 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device return &crypto.KeyShareRejectNoResponse } else if device.Trust == id.TrustStateBlacklisted { return &crypto.KeyShareRejectBlacklisted - } else if trustState, _ := helper.mach.ResolveTrustContext(ctx, device); trustState >= cfg.VerificationLevels.Share { + } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { portal, err := helper.bridge.Bridge.GetPortalByMXID(ctx, info.RoomID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal to handle key request") @@ -267,12 +216,11 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle key request") return &crypto.KeyShareRejectNoResponse } else if user == nil { - zerolog.Ctx(ctx).Debug().Msg("Couldn't find user to handle key request") + // TODO return &crypto.KeyShareRejectNoResponse - } else if !user.Permissions.Admin { - zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: user is not admin") - // TODO is in room check? - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "Key sharing for non-admins is not yet implemented"} + } else if true { + // TODO admin check and is in room check + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "Key sharing is not yet implemented in bridgev2"} } zerolog.Ctx(ctx).Debug().Msg("Accepting key request") return nil @@ -286,39 +234,28 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool if err != nil { return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) } else if len(deviceID) > 0 { - helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database") + helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") } // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) - - initialDeviceDisplayName := fmt.Sprintf("%s bridge", helper.bridge.Bridge.Network.GetName().DisplayName) - if helper.bridge.Config.Encryption.MSC4190 { - helper.log.Debug().Msg("Creating bot device with MSC4190") - err = client.CreateDeviceMSC4190(ctx, deviceID, initialDeviceDisplayName) - if err != nil { - return nil, deviceID != "", fmt.Errorf("failed to create device for bridge bot: %w", err) - } - helper.store.DeviceID = client.DeviceID - return client, deviceID != "", nil - } - flows, err := client.GetLoginFlows(ctx) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") } - resp, err := client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypeAppservice, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID()), }, - DeviceID: deviceID, - StoreCredentials: true, - InitialDeviceDisplayName: initialDeviceDisplayName, + DeviceID: deviceID, + StoreCredentials: true, + + // TODO find proper bridge name + InitialDeviceDisplayName: "Megabridge", // fmt.Sprintf("%s bridge", helper.bridge.ProtocolName), }) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to log in as bridge bot: %w", err) @@ -327,7 +264,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool return client, deviceID != "", nil } -func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool { +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { helper.log.Debug().Msg("Making sure keys are still on server") resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ @@ -340,11 +277,10 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool { } device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] if ok && len(device.Keys) > 0 { - return true + return } helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") helper.Reset(ctx, false) - return false } func (helper *CryptoHelper) Start() { @@ -439,7 +375,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy var encrypted *event.EncryptedEventContent encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { return } helper.log.Debug().Err(err). @@ -554,14 +490,14 @@ func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.D func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { everything := []event.Type{{Type: "*"}} return &mautrix.Filter{ - Presence: &mautrix.FilterPart{NotTypes: everything}, - AccountData: &mautrix.FilterPart{NotTypes: everything}, - Room: &mautrix.RoomFilter{ + Presence: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + Room: mautrix.RoomFilter{ IncludeLeave: false, - Ephemeral: &mautrix.FilterPart{NotTypes: everything}, - AccountData: &mautrix.FilterPart{NotTypes: everything}, - State: &mautrix.FilterPart{NotTypes: everything}, - Timeline: &mautrix.FilterPart{NotTypes: everything}, + Ephemeral: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + State: mautrix.FilterPart{NotTypes: everything}, + Timeline: mautrix.FilterPart{NotTypes: everything}, }, } } diff --git a/bridgev2/matrix/cryptoerror.go b/bridgev2/matrix/cryptoerror.go index ea29703a..55110429 100644 --- a/bridgev2/matrix/cryptoerror.go +++ b/bridgev2/matrix/cryptoerror.go @@ -11,8 +11,8 @@ import ( "errors" "fmt" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/matrix/cryptostore.go b/bridgev2/matrix/cryptostore.go index 4c3b5d30..234797a6 100644 --- a/bridgev2/matrix/cryptostore.go +++ b/bridgev2/matrix/cryptostore.go @@ -45,7 +45,7 @@ func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, WHERE room_id=$1 AND (membership='join' OR membership='invite') AND user_id<>$2 - AND user_id NOT LIKE $3 ESCAPE '\' + AND user_id NOT LIKE $3 `, roomID, store.UserID, store.GhostIDFormat) if err != nil { return diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go index 0667981a..15af0263 100644 --- a/bridgev2/matrix/directmedia.go +++ b/bridgev2/matrix/directmedia.go @@ -13,6 +13,7 @@ import ( "crypto/sha256" "encoding/base64" "fmt" + "net/http" "strings" "maunium.net/go/mautrix" @@ -39,7 +40,7 @@ func (br *Connector) initDirectMedia() error { if err != nil { return fmt.Errorf("failed to initialize media proxy: %w", err) } - br.MediaProxy.RegisterRoutes(br.AS.Router, br.Log.With().Str("component", "media proxy").Logger()) + br.MediaProxy.RegisterRoutes(br.AS.Router) br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed()) dmn.SetUseDirectMedia() br.Log.Debug().Str("server_name", br.MediaProxy.GetServerName()).Msg("Enabled direct media access") @@ -71,7 +72,7 @@ func (br *Connector) GenerateContentURI(ctx context.Context, mediaID networkid.M return mxc, nil } -func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string, params map[string]string) (response mediaproxy.GetMediaResponse, err error) { +func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (response mediaproxy.GetMediaResponse, err error) { mediaID, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(mediaIDStr, br.Config.DirectMedia.MediaIDPrefix)) if err != nil || !bytes.HasPrefix(mediaID, []byte(MediaIDPrefix)) || len(mediaID) < len(MediaIDPrefix)+MediaIDTruncatedHashLength+1 { return nil, mediaproxy.ErrInvalidMediaIDSyntax @@ -79,8 +80,14 @@ func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string, para receivedHash := mediaID[len(mediaID)-MediaIDTruncatedHashLength:] expectedHash := br.hashMediaID(mediaID[:len(mediaID)-MediaIDTruncatedHashLength]) if !hmac.Equal(receivedHash, expectedHash) { - return nil, mautrix.MNotFound.WithMessage("Invalid checksum in media ID part") + return nil, &mediaproxy.ResponseError{ + Status: http.StatusNotFound, + Data: &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "Invalid checksum in media ID part", + }, + } } remoteMediaID := networkid.MediaID(mediaID[len(MediaIDPrefix) : len(mediaID)-MediaIDTruncatedHashLength]) - return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID, params) + return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID) } diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index f7254bd4..e789fa75 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -7,19 +7,14 @@ package matrix import ( - "bytes" "context" - "encoding/json" "errors" "fmt" - "io" - "os" "strings" "sync" "time" "github.com/rs/zerolog" - "go.mau.fi/util/fallocate" "go.mau.fi/util/ptr" "golang.org/x/exp/slices" @@ -28,7 +23,6 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" @@ -45,31 +39,23 @@ type ASIntent struct { var _ bridgev2.MatrixAPI = (*ASIntent)(nil) var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil) -var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil) func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { if extra == nil { extra = &bridgev2.MatrixSendExtra{} } - if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) { + // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions + if eventType == event.EventRedaction { parsedContent := content.Parsed.(*event.RedactionEventContent) - as.Matrix.AddDoublePuppetValue(content) return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ Reason: parsedContent.Reason, Extra: content.Raw, }) } - if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction { - msgContent, ok := content.Parsed.(*event.MessageEventContent) - if ok { - msgContent.AddPerMessageProfileFallback() - } + if eventType != event.EventReaction && eventType != event.EventRedaction { if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) } else if encrypted { - if as.Connector.Crypto == nil { - return nil, fmt.Errorf("room is encrypted, but bridge isn't configured to support encryption") - } if as.Matrix.IsCustomPuppet { if extra.Timestamp.IsZero() { as.Matrix.AddDoublePuppetValue(content) @@ -84,27 +70,16 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType eventType = event.EventEncrypted } } - return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()}) -} - -func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) { - if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { - return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + if extra.Timestamp.IsZero() { + return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content) + } else { + return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli()) } - if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { - return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) - } else if encrypted && as.Connector.Crypto != nil { - if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil { - return nil, err - } - eventType = event.EventEncrypted - } - return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID}) } func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { - targetContent, ok := content.Parsed.(*event.MemberEventContent) - if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" { + targetContent := content.Parsed.(*event.MemberEventContent) + if targetContent.Displayname != "" || targetContent.AvatarURL != "" { return } memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID) @@ -139,7 +114,11 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e if eventType == event.StateMember { as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) } - resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()}) + if ts.IsZero() { + resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) + } else { + resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) + } if err != nil && eventType == event.StateMember { var httpErr mautrix.HTTPError if errors.As(err, &httpErr) && httpErr.RespError != nil && @@ -227,64 +206,7 @@ func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, return data, nil } -func (as *ASIntent) DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool, callback func(*os.File) error) error { - if file != nil { - uri = file.URL - err := file.PrepareForDecryption() - if err != nil { - return err - } - } - parsedURI, err := uri.Parse() - if err != nil { - return err - } - tempFile, err := os.CreateTemp("", "mautrix-download-*") - if err != nil { - return fmt.Errorf("failed to create temp file: %w", err) - } - defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) - }() - resp, err := as.Matrix.Download(ctx, parsedURI) - if err != nil { - return fmt.Errorf("failed to send download request: %w", err) - } - defer resp.Body.Close() - reader := resp.Body - if file != nil { - reader = file.DecryptStream(reader) - } - if resp.ContentLength > 0 { - err = fallocate.Fallocate(tempFile, int(resp.ContentLength)) - if err != nil { - return fmt.Errorf("failed to preallocate file: %w", err) - } - } - _, err = io.Copy(tempFile, reader) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - err = reader.Close() - if err != nil { - return fmt.Errorf("failed to close response body: %w", err) - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return fmt.Errorf("failed to seek to start of temp file: %w", err) - } - err = callback(tempFile) - if err != nil { - return bridgev2.CallbackError{Type: "read", Wrapped: err} - } - return nil -} - func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { - if int64(len(data)) > as.Connector.MediaConfig.UploadSize { - return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(len(data))/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) - } if roomID != "" { var encrypted bool if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { @@ -299,162 +221,12 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by fileName = "" } } - url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{ + req := mautrix.ReqUploadMedia{ ContentBytes: data, ContentType: mimeType, FileName: fileName, - }) - return -} - -func (as *ASIntent) UploadMediaStream( - ctx context.Context, - roomID id.RoomID, - size int64, - requireFile bool, - cb bridgev2.FileStreamCallback, -) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { - if size > as.Connector.MediaConfig.UploadSize { - return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) - } - if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold { - var buf bytes.Buffer - res, err := cb(&buf) - if err != nil { - return "", nil, err - } else if res.ReplacementFile != "" { - panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true")) - } - return as.UploadMedia(ctx, roomID, buf.Bytes(), res.FileName, res.MimeType) - } - var tempFile *os.File - tempFile, err = os.CreateTemp("", "mautrix-upload-*") - if err != nil { - err = fmt.Errorf("failed to create temp file: %w", err) - return - } - removeAndClose := func(f *os.File) { - _ = f.Close() - _ = os.Remove(f.Name()) - } - startedAsyncUpload := false - defer func() { - if !startedAsyncUpload { - removeAndClose(tempFile) - } - }() - if size > 0 { - err = fallocate.Fallocate(tempFile, int(size)) - if err != nil { - err = fmt.Errorf("failed to preallocate file: %w", err) - return - } - } - if roomID != "" { - var encrypted bool - if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { - err = fmt.Errorf("failed to check if room is encrypted: %w", err) - return - } else if encrypted { - file = &event.EncryptedFileInfo{ - EncryptedFile: *attachment.NewEncryptedFile(), - } - } - } - var res *bridgev2.FileStreamResult - res, err = cb(tempFile) - if err != nil { - err = bridgev2.CallbackError{Type: "write", Wrapped: err} - return - } - var replFile *os.File - if res.ReplacementFile != "" { - replFile, err = os.OpenFile(res.ReplacementFile, os.O_RDWR, 0) - if err != nil { - err = fmt.Errorf("failed to open replacement file: %w", err) - return - } - defer func() { - if !startedAsyncUpload { - removeAndClose(replFile) - } - }() - } else { - replFile = tempFile - _, err = replFile.Seek(0, io.SeekStart) - if err != nil { - err = fmt.Errorf("failed to seek to start of temp file: %w", err) - return - } - } - if file != nil { - res.FileName = "" - res.MimeType = "application/octet-stream" - err = file.EncryptFile(replFile) - if err != nil { - err = fmt.Errorf("failed to encrypt file: %w", err) - return - } - _, err = replFile.Seek(0, io.SeekStart) - if err != nil { - err = fmt.Errorf("failed to seek to start of temp file after encrypting: %w", err) - return - } - } - info, err := replFile.Stat() - if err != nil { - err = fmt.Errorf("failed to get temp file info: %w", err) - return - } - size = info.Size() - if size > as.Connector.MediaConfig.UploadSize { - return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) - } - req := mautrix.ReqUploadMedia{ - Content: replFile, - ContentLength: size, - ContentType: res.MimeType, - FileName: res.FileName, } if as.Connector.Config.Homeserver.AsyncMedia { - req.DoneCallback = func() { - removeAndClose(replFile) - removeAndClose(tempFile) - } - req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) - startedAsyncUpload = true - var resp *mautrix.RespCreateMXC - resp, err = as.Matrix.UploadAsync(ctx, req) - if resp != nil { - url = resp.ContentURI.CUString() - } - } else { - var resp *mautrix.RespMediaUpload - resp, err = as.Matrix.UploadMedia(ctx, req) - if resp != nil { - url = resp.ContentURI.CUString() - } - } - if file != nil { - file.URL = url - url = "" - } - return -} - -func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileInfo, req mautrix.ReqUploadMedia) (url id.ContentURIString, err error) { - if as.Connector.Config.Homeserver.AsyncMedia { - if req.ContentBytes != nil { - // Prevent too many background uploads at once - err = as.Connector.uploadSema.Acquire(ctx, int64(len(req.ContentBytes))) - if err != nil { - return - } - req.DoneCallback = func() { - as.Connector.uploadSema.Release(int64(len(req.ContentBytes))) - } - } - req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { @@ -486,78 +258,27 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) } -func dataToFields(data any) (map[string]json.RawMessage, error) { - fields, ok := data.(map[string]json.RawMessage) - if ok { - return fields, nil - } - d, err := json.Marshal(data) - if err != nil { - return nil, err - } - d = canonicaljson.CanonicalJSONAssumeValid(d) - err = json.Unmarshal(d, &fields) - return fields, err -} - -func marshalField(val any) json.RawMessage { - data, _ := json.Marshal(val) - if len(data) > 0 && (data[0] == '{' || data[0] == '[') { - return canonicaljson.CanonicalJSONAssumeValid(data) - } - return data -} - -var nullJSON = json.RawMessage("null") - func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { - if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { - return as.Matrix.BeeperUpdateProfile(ctx, data) - } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo { - fields, err := dataToFields(data) - if err != nil { - return fmt.Errorf("failed to marshal fields: %w", err) - } - currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID) - if err != nil { - return fmt.Errorf("failed to get current profile: %w", err) - } - for key, val := range fields { - existing, ok := currentProfile.Extra[key] - if !ok { - if bytes.Equal(val, nullJSON) { - continue - } - err = as.Matrix.SetProfileField(ctx, key, val) - } else if !bytes.Equal(marshalField(existing), val) { - if bytes.Equal(val, nullJSON) { - err = as.Matrix.DeleteProfileField(ctx, key) - } else { - err = as.Matrix.SetProfileField(ctx, key, val) - } - } - if err != nil { - return fmt.Errorf("failed to set profile field %q: %w", key, err) - } - } + if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + return nil } - return nil + return as.Matrix.BeeperUpdateProfile(ctx, data) } func (as *ASIntent) GetMXID() id.UserID { return as.Matrix.UserID } -func (as *ASIntent) IsDoublePuppet() bool { - return as.Matrix.IsDoublePuppet() +func (as *ASIntent) InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error { + _, err := as.Matrix.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ + Reason: "", + UserID: userID, + }) + return err } -func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...bridgev2.EnsureJoinedParams) error { - var params bridgev2.EnsureJoinedParams - if len(extra) > 0 { - params = extra[0] - } - err := as.Matrix.EnsureJoined(ctx, roomID, appservice.EnsureJoinedParams{Via: params.Via}) +func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { + err := as.Matrix.EnsureJoined(ctx, roomID) if err != nil { return err } @@ -583,39 +304,6 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent { return content } -func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) { - if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { - // Hungryserv doesn't override the capabilities endpoint nor do room versions - return - } - caps := as.Connector.fetchCapabilities(ctx) - roomVer := req.RoomVersion - if roomVer == "" && caps != nil && caps.RoomVersions != nil { - roomVer = id.RoomVersion(caps.RoomVersions.Default) - } - if roomVer != "" && !roomVer.PrivilegedRoomCreators() { - return - } - creators, _ := req.CreationContent["additional_creators"].([]id.UserID) - creators = append(slices.Clone(creators), as.GetMXID()) - if req.PowerLevelOverride != nil { - for _, creator := range creators { - delete(req.PowerLevelOverride.Users, creator) - } - } - for _, evt := range req.InitialState { - if evt.Type != event.StatePowerLevels { - continue - } - content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) - if ok { - for _, creator := range creators { - delete(content.Users, creator) - } - } - } -} - func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { if as.Connector.Config.Encryption.Default { req.InitialState = append(req.InitialState, &event.Event{ @@ -631,7 +319,6 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) } req.CreationContent["m.federate"] = false } - as.filterCreateRequestForV12(ctx, req) resp, err := as.Matrix.CreateRoom(ctx, req) if err != nil { return "", err @@ -673,19 +360,8 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id. } func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { - if roomID == "" { - return nil - } if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { - err := as.Matrix.BeeperDeleteRoom(ctx, roomID) - if err != nil { - return err - } - err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") - } - return nil + return as.Matrix.BeeperDeleteRoom(ctx, roomID) } members, err := as.Matrix.JoinedMembers(ctx, roomID) if err != nil { @@ -711,10 +387,6 @@ func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnl if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to leave room while cleaning up portal") } - err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") - } return nil } @@ -773,23 +445,3 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T }) } } - -func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) { - evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID) - if err != nil { - return nil, err - } - err = evt.Content.ParseRaw(evt.Type) - if err != nil { - zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content") - } - - if evt.Type == event.EventEncrypted { - if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { - return nil, errors.New("can't decrypt the event") - } - return as.Connector.Crypto.Decrypt(ctx, evt) - } - - return evt, nil -} diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 954d0ad9..1117fca2 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -17,8 +17,8 @@ import ( "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -27,11 +27,6 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { if br.shouldIgnoreEvent(evt) { return } - if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember { - zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events") - br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) - return - } if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) @@ -68,10 +63,6 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) case event.EphemeralEventTyping: typingContent := evt.Content.AsTyping() typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser) - case event.BeeperEphemeralEventAIStream: - if br.shouldIgnoreEvent(evt) { - return - } } br.Bridge.QueueMatrixEvent(ctx, evt) } @@ -85,11 +76,6 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() - if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents { - log.Debug().Msg("Dropping event from user with no permission to send events") - br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) - return - } ctx = log.WithContext(ctx) if br.Crypto == nil { br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true) @@ -101,18 +87,17 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) decryptionStart := time.Now() decrypted, err := br.Crypto.Decrypt(ctx, evt) decryptionRetryCount := 0 - var errorEventID id.EventID if errors.Is(err, NoSessionFound) { decryptionRetryCount = 1 log.Debug(). Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") - go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false) + go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false) if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") decrypted, err = br.Crypto.Decrypt(ctx, evt) } else { - go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID) + go br.waitLongerForSession(ctx, evt, decryptionStart) return } } @@ -121,18 +106,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true) return } - br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart)) + br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart)) } -func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) { +func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { log := zerolog.Ctx(ctx) content := evt.Content.AsEncrypted() log.Debug(). Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") - //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + var errorEventID *id.EventID go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -157,7 +142,7 @@ type CommandProcessor interface { } func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event, step status.MessageCheckpointStep, retryNum int) { - err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{{ + err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{{ RoomID: evt.RoomID, EventID: evt.ID, EventType: evt.Type, @@ -184,7 +169,7 @@ func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool { } func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { - if br.shouldIgnoreEventFromUser(evt.Sender) && evt.Type != event.StateTombstone { + if br.shouldIgnoreEventFromUser(evt.Sender) { return true } dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey] @@ -235,6 +220,7 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration + decrypted.Mautrix.EventSource |= event.SourceDecrypted br.EventProcessor.Dispatch(ctx, decrypted) if errorEventID != nil && *errorEventID != "" { _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID) diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go index f5e438de..0f6aa68c 100644 --- a/bridgev2/matrix/mxmain/dberror.go +++ b/bridgev2/matrix/mxmain/dberror.go @@ -66,12 +66,7 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s } else if errors.Is(err, dbutil.ErrForeignTables) { br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { - var noe dbutil.NotOwnedError - if errors.As(err, &noe) && noe.Owner == br.Name { - br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?") - } else { - br.Log.Info().Msg("Sharing the same database with different programs is not supported") - } + br.Log.Info().Msg("Sharing the same database with different programs is not supported") } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { br.Log.Info().Msg("Downgrading the bridge is not supported") } diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go deleted file mode 100644 index 1b4f1467..00000000 --- a/bridgev2/matrix/mxmain/envconfig.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mxmain - -import ( - "fmt" - "iter" - "os" - "reflect" - "strconv" - "strings" - - "go.mau.fi/util/random" -) - -var randomParseFilePrefix = random.String(16) + "READFILE:" - -func parseEnv(prefix string) iter.Seq2[[]string, string] { - return func(yield func([]string, string) bool) { - for _, s := range os.Environ() { - if !strings.HasPrefix(s, prefix) { - continue - } - kv := strings.SplitN(s, "=", 2) - key := strings.TrimPrefix(kv[0], prefix) - value := kv[1] - if strings.HasSuffix(key, "_FILE") { - key = strings.TrimSuffix(key, "_FILE") - value = randomParseFilePrefix + value - } - key = strings.ToLower(key) - if !strings.ContainsRune(key, '.') { - key = strings.ReplaceAll(key, "__", ".") - } - if !yield(strings.Split(key, "."), value) { - return - } - } - } -} - -func reflectYAMLFieldName(f *reflect.StructField) string { - parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2) - fieldName := parts[0] - if fieldName == "-" && len(parts) == 1 { - return "" - } - if fieldName == "" { - return strings.ToLower(f.Name) - } - return fieldName -} - -type reflectGetResult struct { - val reflect.Value - valKind reflect.Kind - remainingPath []string -} - -func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) { - if len(path) == 0 { - return &reflectGetResult{val: rv, valKind: rv.Kind()}, true - } - if rv.Kind() == reflect.Ptr { - rv = rv.Elem() - } - switch rv.Kind() { - case reflect.Map: - return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true - case reflect.Struct: - fields := reflect.VisibleFields(rv.Type()) - for _, field := range fields { - fieldName := reflectYAMLFieldName(&field) - if fieldName != "" && fieldName == path[0] { - return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:]) - } - } - default: - } - return nil, false -} - -func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) { - if len(path) > 0 && path[0] == "network" { - return reflectGetYAML(network, path[1:]) - } - return reflectGetYAML(main, path) -} - -func formatKeyString(key []string) string { - return strings.Join(key, "->") -} - -func UpdateConfigFromEnv(cfg, networkData any, prefix string) error { - cfgVal := reflect.ValueOf(cfg) - networkVal := reflect.ValueOf(networkData) - for key, value := range parseEnv(prefix) { - field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key) - if !ok { - return fmt.Errorf("%s not found", formatKeyString(key)) - } - if strings.HasPrefix(value, randomParseFilePrefix) { - filepath := strings.TrimPrefix(value, randomParseFilePrefix) - fileData, err := os.ReadFile(filepath) - if err != nil { - return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err) - } - value = strings.TrimSpace(string(fileData)) - } - var parsedVal any - var err error - switch field.valKind { - case reflect.String: - parsedVal = value - case reflect.Bool: - parsedVal, err = strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - parsedVal, err = strconv.ParseInt(value, 10, 64) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - parsedVal, err = strconv.ParseUint(value, 10, 64) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - case reflect.Float32, reflect.Float64: - parsedVal, err = strconv.ParseFloat(value, 64) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - default: - return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key)) - } - if field.val.Kind() == reflect.Ptr { - if field.val.IsNil() { - field.val.Set(reflect.New(field.val.Type().Elem())) - } - field.val = field.val.Elem() - } - if field.val.Kind() == reflect.Map { - key = key[:len(key)-len(field.remainingPath)] - mapKeyStr := strings.Join(field.remainingPath, ".") - key = append(key, mapKeyStr) - if field.val.Type().Key().Kind() != reflect.String { - return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key)) - } - field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal)) - } else { - field.val.Set(reflect.ValueOf(parsedVal)) - } - } - return nil -} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index ccc81c4b..06ed010f 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -6,55 +6,16 @@ bridge: personal_filtering_spaces: true # Whether the bridge should set names and avatars explicitly for DM portals. # This is only necessary when using clients that don't support MSC4171. - private_chat_portal_meta: true - # Should events be handled asynchronously within portal rooms? - # If true, events may end up being out of order, but slow events won't block other ones. - # This is not yet safe to use. - async_events: false - # Should every user have their own portals rather than sharing them? - # By default, users who are in the same group on the remote network will be - # in the same Matrix room bridged to that group. If this is set to true, - # every user will get their own Matrix room instead. - # SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST. - split_portals: false - # Should the bridge resend `m.bridge` events to all portals on startup? - resend_bridge_info: false - # Should `m.bridge` events be sent without a state key? - # By default, the bridge uses a unique key that won't conflict with other bridges. - no_bridge_info_state_key: false - # Should bridge connection status be sent to the management room as `m.notice` events? - # These contain the same data that can be posted to an external HTTP server using homeserver -> status_endpoint. - # Allowed values: none, errors, all - bridge_status_notices: errors - # How long after an unknown error should the bridge attempt a full reconnect? - # Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value. - unknown_error_auto_reconnect: null - # Maximum number of times to do the auto-reconnect above. - # The counter is per login, but is never reset except on logout and restart. - unknown_error_max_auto_reconnects: 10 + private_chat_portal_meta: false # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false - # Should `m.notice` messages be bridged? - bridge_notices: false # Should room tags only be synced when creating the portal? Tags mean things like favorite/pin and archive/low priority. # Tags currently can't be synced back to the remote network, so a continuous sync means tagging from Matrix will be undone. tag_only_on_create: true - # List of tags to allow bridging. If empty, no tags will be bridged. - only_bridge_tags: [m.favourite, m.lowpriority] # Should room mute status only be synced when creating the portal? # Like tags, mutes can't currently be synced back to the remote network. mute_only_on_create: true - # Should the bridge check the db to ensure that incoming events haven't been handled before - deduplicate_matrix_messages: false - # Should cross-room reply metadata be bridged? - # Most Matrix clients don't support this and servers may reject such messages too. - cross_room_replies: false - # If a state event fails to bridge, should the bridge revert any state changes made by that event? - revert_failed_state_changes: false - # In portals with no relay set, should Matrix users be kicked if they're - # not logged into an account that's in the remote chat? - kick_matrix_users: true # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: @@ -187,12 +148,8 @@ homeserver: # Changing these values requires regeneration of the registration (except when noted otherwise) appservice: # The address that the homeserver can use to connect to this appservice. - # Like the homeserver address, a local non-https address is recommended when the bridge is on the same machine. - # If the bridge is elsewhere, you must secure the connection yourself (e.g. with https or wireguard) - # If you want to use https, you need to use a reverse proxy. The bridge does not have TLS support built in. address: http://localhost:$<> # A public address that external services can use to reach this appservice. - # This is only needed for things like public media. A reverse proxy is generally necessary when using this field. # This value doesn't affect the registration file. public_address: https://bridge.example.com @@ -237,30 +194,17 @@ matrix: # Whether the bridge should send error notices via m.notice events when a message fails to bridge. message_error_notices: true # Whether the bridge should update the m.direct account data event when double puppeting is enabled. - sync_direct_chat_list: true + sync_direct_chat_list: false # Whether created rooms should have federation enabled. If false, created portal rooms # will never be federated. Changing this option requires recreating rooms. federate_rooms: true - # The threshold as bytes after which the bridge should roundtrip uploads via the disk - # rather than keeping the whole file in memory. - upload_file_threshold: 5242880 - # Should the bridge set additional custom profile info for ghosts? - # This can make a lot of requests, as there's no batch profile update endpoint. - ghost_extra_profile_info: false - -# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors. -analytics: - # API key to send with tracking requests. Tracking is disabled if this is null. - token: null - # Address to send tracking requests to. - url: https://api.segment.io/v1/track - # Optional user ID for tracking events. If null, defaults to using Matrix user ID. - user_id: null # Settings for provisioning API provisioning: + # Prefix for the provisioning API paths. + prefix: /_matrix/provision # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, - # or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters. + # or if set to "disable", the provisioning API will be disabled. shared_secret: generate # Whether to allow provisioning API requests to be authed using Matrix access tokens. # This follows the same rules as double puppeting to determine which server to contact to check the token, @@ -268,9 +212,6 @@ provisioning: allow_matrix_auth: true # Enable debug API at /debug with provisioning authentication. debug_endpoints: false - # Enable session transfers between bridges. Note that this only validates Matrix or shared secret - # auth before passing live network client credentials down in the response. - enable_session_transfers: false # Some networks require publicly accessible media download links (e.g. for user avatars when using Discord webhooks). # These settings control whether the bridge will provide such public media access. @@ -286,14 +227,6 @@ public_media: expiry: 0 # Length of hash to use for public media URLs. Must be between 0 and 32. hash_length: 32 - # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served. - # If you change this, you must configure your reverse proxy to rewrite the path accordingly. - path_prefix: /_mautrix/publicmedia - # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs? - # If false, the generated URLs will just have the MXC URI and a HMAC signature. - # The hash_length field will be used to decide the length of the generated URL. - # This also allows invalidating URLs by deleting the database entry. - use_database: false # Settings for converting remote media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html @@ -375,21 +308,9 @@ encryption: default: false # Whether to require all messages to be encrypted and drop any unencrypted messages. require: false - # Whether to use MSC3202/MSC4203 instead of /sync long polling for receiving encryption-related data. + # Whether to use MSC2409/MSC3202 instead of /sync long polling for receiving encryption-related data. # This option is not yet compatible with standard Matrix servers like Synapse and should not be used. - # Changing this option requires updating the appservice registration file. appservice: false - # Whether to use MSC4190 instead of appservice login to create the bridge bot device. - # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. - # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). - # Changing this option requires updating the appservice registration file. - msc4190: false - # Whether to encrypt reactions and reply metadata as per MSC4392. - msc4392: false - # Should the bridge bot generate a recovery key and cross-signing keys and verify itself? - # Note that without the latest version of MSC4190, this will fail if you reset the bridge database. - # The generated recovery key will be saved in the kv_store table under `recovery_key`. - self_sign: false # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. # You must use a client that supports requesting keys from other users to use this feature. allow_key_sharing: true @@ -452,16 +373,6 @@ encryption: # You should not enable this option unless you understand all the implications. disable_device_change_key_rotation: false -# Prefix for environment variables. All variables with this prefix must map to valid config fields. -# Nesting in variable names is represented with a dot (.). -# If there are no dots in the name, two underscores (__) are replaced with a dot. -# -# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token. -# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily. -# -# If this is null, reading config fields from environment will be disabled. -env_config_prefix: null - # Logging config. See https://github.com/tulir/zeroconfig for details. logging: min_level: debug diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 97cdeddf..32556de1 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -14,7 +14,6 @@ import ( "fmt" "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridgev2" @@ -24,25 +23,9 @@ import ( "maunium.net/go/mautrix/id" ) -func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery string, newDBVersion int, otherTable dbutil.UpgradeTable, otherTableName string, otherNewVersion int) func(ctx context.Context) error { +func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { return func(ctx context.Context) error { - // Unique constraints must have globally unique names on postgres, and renaming the table doesn't rename them, - // so just drop the ones that may conflict with the new schema. - if br.DB.Dialect == dbutil.Postgres { - _, err := br.DB.Exec(ctx, "ALTER TABLE message DROP CONSTRAINT IF EXISTS message_mxid_unique") - if err != nil { - return fmt.Errorf("failed to drop potentially conflicting constraint on message: %w", err) - } - _, err = br.DB.Exec(ctx, "ALTER TABLE reaction DROP CONSTRAINT IF EXISTS reaction_mxid_unique") - if err != nil { - return fmt.Errorf("failed to drop potentially conflicting constraint on reaction: %w", err) - } - } - err := dbutil.DangerousInternalUpgradeVersionTable(ctx, br.DB) - if err != nil { - return err - } - _, err = br.DB.Exec(ctx, renameTablesQuery) + _, err := br.DB.Exec(ctx, renameTablesQuery) if err != nil { return err } @@ -53,22 +36,6 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa if upgradesTo < newDBVersion || compat > newDBVersion { return fmt.Errorf("unexpected new database version (%d/c:%d, expected %d)", upgradesTo, compat, newDBVersion) } - if otherTable != nil { - _, err = br.DB.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", otherTableName)) - if err != nil { - return err - } - otherUpgradesTo, otherCompat, err := otherTable[0].DangerouslyRun(ctx, br.DB) - if err != nil { - return err - } else if otherUpgradesTo < otherNewVersion || otherCompat > otherNewVersion { - return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, otherUpgradesTo, otherCompat, otherNewVersion) - } - _, err = br.DB.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", otherTableName), otherUpgradesTo, otherCompat) - if err != nil { - return err - } - } copyDataQuery, err = br.DB.Internals().FilterSQLUpgrade(bytes.Split([]byte(copyDataQuery), []byte("\n"))) if err != nil { return err @@ -77,19 +44,11 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa if err != nil { return err } - _, err = br.DB.Exec(ctx, "DELETE FROM database_owner") + _, err = br.DB.Exec(ctx, "UPDATE database_owner SET owner = $1 WHERE key = 0", br.DB.Owner) if err != nil { return err } - _, err = br.DB.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", br.DB.Owner) - if err != nil { - return err - } - _, err = br.DB.Exec(ctx, "DELETE FROM version") - if err != nil { - return err - } - _, err = br.DB.Exec(ctx, "INSERT INTO version (version, compat) VALUES ($1, $2)", upgradesTo, compat) + _, err = br.DB.Exec(ctx, "UPDATE version SET version = $1, compat = $2", upgradesTo, compat) if err != nil { return err } @@ -102,17 +61,7 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa } } -func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { - return br.LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery, newDBVersion, nil, "", 0) -} - -func (br *BridgeMain) CheckLegacyDB( - expectedVersion int, - minBridgeVersion, - firstMegaVersion string, - migrator func(context.Context) error, - transaction bool, -) { +func (br *BridgeMain) CheckLegacyDB(expectedVersion int, minBridgeVersion, firstMegaVersion string, migrator func(context.Context) error, transaction bool) { log := br.Log.With().Str("action", "migrate legacy db").Logger() ctx := log.WithContext(context.Background()) exists, err := br.DB.TableExists(ctx, "database_owner") @@ -123,7 +72,7 @@ func (br *BridgeMain) CheckLegacyDB( return } var owner string - err = br.DB.QueryRow(ctx, "SELECT owner FROM database_owner LIMIT 1").Scan(&owner) + err = br.DB.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner) if err != nil && !errors.Is(err, sql.ErrNoRows) { log.Err(err).Msg("Failed to get database owner") return @@ -135,10 +84,7 @@ func (br *BridgeMain) CheckLegacyDB( } var dbVersion int err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) - if err != nil { - log.Fatal().Err(err).Msg("Failed to get database version") - return - } else if dbVersion < expectedVersion { + if dbVersion < expectedVersion { log.Fatal(). Int("expected_version", expectedVersion). Int("version", dbVersion). @@ -211,46 +157,31 @@ func (br *BridgeMain) postMigrateDMPortal(ctx context.Context, portal *bridgev2. } func (br *BridgeMain) PostMigrate(ctx context.Context) error { - log := br.Log.With().Str("action", "post-migrate").Logger() wasMigrated, err := br.DB.TableExists(ctx, "database_was_migrated") if err != nil { return fmt.Errorf("failed to check if database_was_migrated table exists: %w", err) } else if !wasMigrated { return nil } - log.Info().Msg("Doing post-migration updates to Matrix rooms") + zerolog.Ctx(ctx).Info().Msg("Doing post-migration updates to Matrix rooms") portals, err := br.Bridge.GetAllPortalsWithMXID(ctx) if err != nil { return fmt.Errorf("failed to get all portals: %w", err) } for _, portal := range portals { - log := log.With(). - Stringer("room_id", portal.MXID). - Object("portal_key", portal.PortalKey). - Str("room_type", string(portal.RoomType)). - Logger() - log.Debug().Msg("Migrating portal") - if br.PostMigratePortal != nil { - err = br.PostMigratePortal(ctx, portal) + switch portal.RoomType { + case database.RoomTypeDM: + err = br.postMigrateDMPortal(ctx, portal) if err != nil { - log.Err(err).Msg("Failed to run post-migrate portal hook") - continue - } - } else { - switch portal.RoomType { - case database.RoomTypeDM: - err = br.postMigrateDMPortal(ctx, portal) - if err != nil { - return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) - } + return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) } } _, err = br.Matrix.Bot.SendStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.ElementFunctionalMembersContent{ ServiceMembers: []id.UserID{br.Matrix.Bot.UserID}, }) if err != nil { - log.Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") + zerolog.Ctx(ctx).Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") } } @@ -258,6 +189,6 @@ func (br *BridgeMain) PostMigrate(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to drop database_was_migrated table: %w", err) } - log.Info().Msg("Post-migration updates complete") + zerolog.Ctx(ctx).Info().Msg("Post-migration updates complete") return nil } diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 1e8b51d1..af0868bf 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -26,7 +26,6 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/util/exerrors" "go.mau.fi/util/exzerolog" - "go.mau.fi/util/progver" "gopkg.in/yaml.v3" flag "maunium.net/go/mauflag" @@ -63,18 +62,11 @@ type BridgeMain struct { // git tag to see if the built version is the release or a dev build. // You can either bump this right after a release or right before, as long as it matches on the release commit. Version string - // SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning, - // such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch. - SemCalVer bool // PostInit is a function that will be called after the bridge has been initialized but before it is started. PostInit func() PostStart func() - // PostMigratePortal is a function that will be called during a legacy - // migration for each portal. - PostMigratePortal func(context.Context, *bridgev2.Portal) error - // Connector is the network connector for the bridge. Connector bridgev2.NetworkConnector @@ -90,7 +82,11 @@ type BridgeMain struct { RegistrationPath string SaveConfig bool - ver progver.ProgramVersion + baseVersion string + commit string + LinkifiedVersion string + VersionDesc string + BuildTime time.Time AdditionalShortFlags string AdditionalLongFlags string @@ -99,7 +95,14 @@ type BridgeMain struct { } type VersionJSONOutput struct { - progver.ProgramVersion + Name string + URL string + + Version string + IsRelease bool + Commit string + FormattedVersion string + BuildTime time.Time OS string Arch string @@ -140,11 +143,18 @@ func (br *BridgeMain) PreInit() { flag.PrintHelp() os.Exit(0) } else if *version { - fmt.Println(br.ver.VersionDescription) + fmt.Println(br.VersionDesc) os.Exit(0) } else if *versionJSON { output := VersionJSONOutput{ - ProgramVersion: br.ver, + URL: br.URL, + Name: br.Name, + + Version: br.baseVersion, + IsRelease: br.Version == br.baseVersion, + Commit: br.commit, + FormattedVersion: br.Version, + BuildTime: br.BuildTime, OS: runtime.GOOS, Arch: runtime.GOARCH, @@ -226,8 +236,8 @@ func (br *BridgeMain) Init() { br.Log.Info(). Str("name", br.Name). - Str("version", br.ver.FormattedVersion). - Time("built_at", br.ver.BuildTime). + Str("version", br.Version). + Time("built_at", br.BuildTime). Str("go_version", runtime.Version()). Msg("Initializing bridge") @@ -241,7 +251,7 @@ func (br *BridgeMain) Init() { br.Matrix.AS.DoublePuppetValue = br.Name br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{ Func: func(ce *commands.Event) { - ce.Reply(br.ver.MarkdownDescription()) + ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123)) }, Name: "version", Help: commands.HelpMeta{ @@ -257,10 +267,6 @@ func (br *BridgeMain) Init() { func (br *BridgeMain) initDB() { br.Log.Debug().Msg("Initializing database connection") dbConfig := br.Config.Database - if dbConfig.Type == "sqlite3" { - br.Log.WithLevel(zerolog.FatalLevel).Msg("Invalid database type sqlite3. Use sqlite3-fk-wal instead.") - os.Exit(14) - } if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { var fixedExampleURI string if !strings.HasPrefix(dbConfig.URI, "file:") { @@ -300,7 +306,7 @@ func (br *BridgeMain) validateConfig() error { case br.Config.AppService.HSToken == "This value is generated when generating the registration": return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") case br.Config.Database.URI == "postgres://user:password@host/database?sslmode=disable": - return errors.New("database.uri not configured") + return errors.New("appservice.database not configured") case !br.Config.Bridge.Permissions.IsConfigured(): return errors.New("bridge.permissions not configured") case !strings.Contains(br.Config.AppService.FormatUsername("1234567890"), "1234567890"): @@ -354,21 +360,13 @@ func (br *BridgeMain) LoadConfig() { } } cfg.Bridge.Backfill = cfg.Backfill - if cfg.EnvConfigPrefix != "" { - err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err) - os.Exit(10) - } - } br.Config = &cfg } // Start starts the bridge after everything has been initialized. // This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Start() { - ctx := br.Log.WithContext(context.Background()) - err := br.Bridge.StartConnectors(ctx) + err := br.Bridge.StartConnectors() if err != nil { var dbUpgradeErr bridgev2.DBUpgradeError if errors.As(err, &dbUpgradeErr) { @@ -377,15 +375,14 @@ func (br *BridgeMain) Start() { br.Log.Fatal().Err(err).Msg("Failed to start bridge") } } - err = br.PostMigrate(ctx) + err = br.PostMigrate(br.Log.WithContext(context.Background())) if err != nil { br.Log.Fatal().Err(err).Msg("Failed to run post-migration updates") } - err = br.Bridge.StartLogins(ctx) + err = br.Bridge.StartLogins() if err != nil { br.Log.Fatal().Err(err).Msg("Failed to start existing user logins") } - br.Bridge.PostStart(ctx) if br.PostStart != nil { br.PostStart() } @@ -397,10 +394,8 @@ func (br *BridgeMain) WaitForInterrupt() int { signal.Notify(c, os.Interrupt, syscall.SIGTERM) select { case <-c: - br.Log.Info().Msg("Interrupt signal received from OS") return 0 case exitCode := <-br.manualStop: - br.Log.Info().Msg("Internal stop signal received") return exitCode } } @@ -414,7 +409,7 @@ func (br *BridgeMain) TriggerStop(exitCode int) { // Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Stop() { - br.Bridge.StopWithTimeout(5 * time.Second) + br.Bridge.Stop() } // InitVersion formats the bridge version and build time nicely for things like @@ -439,12 +434,42 @@ func (br *BridgeMain) Stop() { // // (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`) func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { - br.ver = progver.ProgramVersion{ - Name: br.Name, - URL: br.URL, - BaseVersion: br.Version, - SemCalVer: br.SemCalVer, - }.Init(tag, commit, rawBuildTime) - mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent) - br.Version = br.ver.FormattedVersion + br.baseVersion = br.Version + if len(tag) > 0 && tag[0] == 'v' { + tag = tag[1:] + } + if tag != br.Version { + suffix := "" + if !strings.HasSuffix(br.Version, "+dev") { + suffix = "+dev" + } + if len(commit) > 8 { + br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) + } else { + br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) + } + } + + br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) + if tag == br.Version { + br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) + } else if len(commit) > 8 { + br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) + } + var buildTime time.Time + if rawBuildTime != "unknown" { + buildTime, _ = time.Parse(time.RFC3339, rawBuildTime) + } + var builtWith string + if buildTime.IsZero() { + rawBuildTime = "unknown" + builtWith = runtime.Version() + } else { + rawBuildTime = buildTime.Format(time.RFC1123) + builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version()) + } + mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) + br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith) + br.commit = commit + br.BuildTime = buildTime } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 243b91da..107837ef 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -12,26 +12,21 @@ import ( "errors" "fmt" "net/http" - "net/http/pprof" "strings" "sync" "time" + "github.com/gorilla/mux" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exhttp" - "go.mau.fi/util/exstrings" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/provisionutil" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/federation" "maunium.net/go/mautrix/id" ) @@ -42,7 +37,7 @@ type matrixAuthCacheEntry struct { } type ProvisioningAPI struct { - Router *http.ServeMux + Router *mux.Router br *Connector log zerolog.Logger @@ -55,20 +50,6 @@ type ProvisioningAPI struct { matrixAuthCache map[string]matrixAuthCacheEntry matrixAuthCacheLock sync.Mutex - - // Set for a given login once credentials have been exported, once in this state the finish - // API is available which will call logout on the client in question. - sessionTransfers map[networkid.UserLoginID]struct{} - sessionTransfersLock sync.Mutex - - // GetAuthFromRequest is a custom function for getting the auth token from - // the request if the Authorization header is not present. - GetAuthFromRequest func(r *http.Request) string - - // GetUserIDFromRequest is a custom function for getting the user ID to - // authenticate as instead of using the user ID provided in the query - // parameter. - GetUserIDFromRequest func(r *http.Request) id.UserID } type ProvLogin struct { @@ -85,84 +66,78 @@ const ( provisioningUserKey provisioningContextKey = iota provisioningUserLoginKey provisioningLoginProcessKey - ProvisioningKeyRequest ) func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } -func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { +func (prov *ProvisioningAPI) GetRouter() *mux.Router { return prov.Router } -func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI { +type IProvisioningAPI interface { + GetRouter() *mux.Router + GetUser(r *http.Request) *bridgev2.User +} + +func (br *Connector) GetProvisioning() IProvisioningAPI { return br.Provisioning } func (prov *ProvisioningAPI) Init() { prov.matrixAuthCache = make(map[string]matrixAuthCacheEntry) prov.logins = make(map[string]*ProvLogin) - prov.sessionTransfers = make(map[networkid.UserLoginID]struct{}) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() - prov.fedClient = federation.NewClient("", nil, nil) + prov.fedClient = federation.NewClient("", nil) prov.fedClient.HTTP.Timeout = 20 * time.Second tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport) tp.Dialer.Timeout = 10 * time.Second tp.Transport.ResponseHeaderTimeout = 10 * time.Second tp.Transport.TLSHandshakeTimeout = 10 * time.Second - prov.Router = http.NewServeMux() - prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami) - prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities) - prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows) - prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) - prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep) - prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout) - prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins) - prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList) - prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) - prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) - prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) - prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup) - - if prov.br.Config.Provisioning.EnableSessionTransfers { - prov.log.Debug().Msg("Enabling session transfer API") - prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer) - prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer) - } + prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() + prov.Router.Use(hlog.NewHandler(prov.log)) + prov.Router.Use(corsMiddleware) + prov.Router.Use(requestlog.AccessLogger(false)) + prov.Router.Use(prov.AuthMiddleware) + prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) + prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) + prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) + prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput) + prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait) + prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout) + prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins) + prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList) + prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier) + prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) + prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") - debugRouter := http.NewServeMux() - debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline) - debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile) - debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol) - debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace) - debugRouter.HandleFunc("/pprof/", pprof.Index) - prov.br.AS.Router.Handle("/debug/", exhttp.ApplyMiddleware( - debugRouter, - exhttp.StripPrefix("/debug"), - hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()), - requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), - prov.DebugAuthMiddleware, - )) + r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() + r.Use(prov.AuthMiddleware) + r.PathPrefix("/pprof").Handler(http.DefaultServeMux) } +} - errorBodies := exhttp.ErrorBodies{ - NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), - MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), - } - prov.br.AS.Router.Handle("/_matrix/provision/", exhttp.ApplyMiddleware( - prov.Router, - exhttp.StripPrefix("/_matrix/provision"), - hlog.NewHandler(prov.log), - hlog.RequestIDHandler("request_id", "Request-Id"), - exhttp.CORSMiddleware, - requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), - exhttp.HandleErrors(errorBodies), - prov.AuthMiddleware, - )) +func corsMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + handler.ServeHTTP(w, r) + }) +} + +func jsonResponse(w http.ResponseWriter, status int, response any) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(response) } func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { @@ -206,46 +181,18 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI } } -func disabledAuth(w http.ResponseWriter, r *http.Request) { - mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w) -} - -func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { - secret := prov.br.Config.Provisioning.SharedSecret - if len(secret) < 16 { - return http.HandlerFunc(disabledAuth) - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - if auth == "" { - mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) - } else if !exstrings.ConstantTimeEqual(auth, secret) { - mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) - } else { - h.ServeHTTP(w, r) - } - }) -} - func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { - secret := prov.br.Config.Provisioning.SharedSecret - if len(secret) < 16 { - return http.HandlerFunc(disabledAuth) - } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") - if auth == "" && prov.GetAuthFromRequest != nil { - auth = prov.GetAuthFromRequest(r) - } if auth == "" { - mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) + jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ + Err: "Missing auth token", + ErrCode: mautrix.MMissingToken.ErrCode, + }) return } userID := id.UserID(r.URL.Query().Get("user_id")) - if userID == "" && prov.GetUserIDFromRequest != nil { - userID = prov.GetUserIDFromRequest(r) - } - if !exstrings.ConstantTimeEqual(auth, secret) { + if auth != prov.br.Config.Provisioning.SharedSecret { var err error if strings.HasPrefix(auth, "openid:") { err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:")) @@ -255,25 +202,66 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). Msg("Provisioning API request contained invalid auth") - mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) + jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ + Err: "Invalid auth token", + ErrCode: mautrix.MUnknownToken.ErrCode, + }) return } } user, err := prov.br.Bridge.GetUserByMXID(r.Context(), userID) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user") - mautrix.MUnknown.WithMessage("Failed to get user").Write(w) + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to get user", + ErrCode: "M_UNKNOWN", + }) return } // TODO handle user being nil? - // TODO per-endpoint permissions? - if !user.Permissions.Login { - mautrix.MForbidden.WithMessage("User does not have login permissions").Write(w) - return - } - ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) - ctx = context.WithValue(ctx, provisioningUserKey, user) + ctx := context.WithValue(r.Context(), provisioningUserKey, user) + if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { + prov.loginsLock.RLock() + login, ok := prov.logins[loginID] + prov.loginsLock.RUnlock() + if !ok { + zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "Login not found", + ErrCode: mautrix.MNotFound.ErrCode, + }) + return + } + login.Lock.Lock() + // This will only unlock after the handler runs + defer login.Lock.Unlock() + stepID := mux.Vars(r)["stepID"] + if login.NextStep.StepID != stepID { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_id", stepID). + Str("expected_step_id", login.NextStep.StepID). + Msg("Step ID does not match") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Step ID does not match", + ErrCode: mautrix.MBadState.ErrCode, + }) + return + } + stepType := mux.Vars(r)["stepType"] + if login.NextStep.Type != bridgev2.LoginStepType(stepType) { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_type", stepType). + Str("expected_step_type", string(login.NextStep.Type)). + Msg("Step type does not match") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Step type does not match", + ErrCode: mautrix.MBadState.ErrCode, + }) + return + } + ctx = context.WithValue(r.Context(), provisioningLoginProcessKey, login) + } h.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -316,7 +304,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { CommandPrefix: prov.br.Config.Bridge.CommandPrefix, ManagementRoom: user.ManagementRoom, } - logins := user.GetUserLogins() + logins := user.GetCachedUserLogins() resp.Logins = make([]RespWhoamiLogin, len(logins)) for i, login := range logins { prevState := login.BridgeState.GetPrevUnsent() @@ -324,7 +312,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { prevState.UserID = "" prevState.RemoteID = "" prevState.RemoteName = "" - prevState.RemoteProfile = status.RemoteProfile{} + prevState.RemoteProfile = nil resp.Logins[i] = RespWhoamiLogin{ StateEvent: prevState.StateEvent, StateTS: prevState.Timestamp, @@ -338,7 +326,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { SpaceRoom: login.SpaceRoom, } } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + jsonResponse(w, http.StatusOK, resp) } type RespLoginFlows struct { @@ -351,47 +339,30 @@ type RespSubmitLogin struct { } func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Request) { - exhttp.WriteJSONResponse(w, http.StatusOK, &RespLoginFlows{ + jsonResponse(w, http.StatusOK, &RespLoginFlows{ Flows: prov.net.GetLoginFlows(), }) } -func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) { - exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning) -} - -var ErrNilStep = errors.New("bridge returned nil step with no error") -var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"} - func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) if failed { return } - user := prov.GetUser(r) - if overrideLogin == nil && user.HasTooManyLogins() { - ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w) - return - } - login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID")) + login, err := prov.net.CreateLogin( + r.Context(), + prov.GetUser(r), + mux.Vars(r)["flowID"], + ) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") - RespondWithError(w, err, "Internal error creating login process") + respondMaybeCustomError(w, err, "Internal error creating login process") return } - var firstStep *bridgev2.LoginStep - overridable, ok := login.(bridgev2.LoginProcessWithOverride) - if ok && overrideLogin != nil { - firstStep, err = overridable.StartWithOverride(r.Context(), overrideLogin) - } else { - firstStep, err = login.Start(r.Context()) - } - if err == nil && firstStep == nil { - err = ErrNilStep - } + firstStep, err := login.Start(r.Context()) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") - RespondWithError(w, err, "Internal error starting login") + respondMaybeCustomError(w, err, "Internal error starting login") return } loginID := xid.New().String() @@ -403,18 +374,10 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque Override: overrideLogin, } prov.loginsLock.Unlock() - zerolog.Ctx(r.Context()).Info(). - Any("first_step", firstStep). - Msg("Created login process") - exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) + jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { - zerolog.Ctx(ctx).Info(). - Str("step_id", step.StepID). - Str("user_login_id", string(step.CompleteParams.UserLoginID)). - Msg("Login completed successfully") - prov.deleteLogin(login, false) if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { return } @@ -428,67 +391,15 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov }, bridgev2.DeleteOpts{LogoutRemote: true}) } -func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) { - if cancel { - login.Process.Cancel() - } - prov.loginsLock.Lock() - delete(prov.logins, login.ID) - prov.loginsLock.Unlock() -} - -func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { - loginID := r.PathValue("loginProcessID") - prov.loginsLock.RLock() - login, ok := prov.logins[loginID] - prov.loginsLock.RUnlock() - if !ok { - zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") - mautrix.MNotFound.WithMessage("Login not found").Write(w) - return - } - login.Lock.Lock() - // This will only unlock after the handler runs - defer login.Lock.Unlock() - stepID := r.PathValue("stepID") - if login.NextStep.StepID != stepID { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_id", stepID). - Str("expected_step_id", login.NextStep.StepID). - Msg("Step ID does not match") - mautrix.MBadState.WithMessage("Step ID does not match").Write(w) - return - } - stepType := r.PathValue("stepType") - if login.NextStep.Type != bridgev2.LoginStepType(stepType) { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_type", stepType). - Str("expected_step_type", string(login.NextStep.Type)). - Msg("Step type does not match") - mautrix.MBadState.WithMessage("Step type does not match").Write(w) - return - } - ctx := context.WithValue(r.Context(), provisioningLoginProcessKey, login) - r = r.WithContext(ctx) - switch bridgev2.LoginStepType(r.PathValue("stepType")) { - case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies: - prov.PostLoginSubmitInput(w, r) - case bridgev2.LoginStepTypeDisplayAndWait: - prov.PostLoginWait(w, r) - case bridgev2.LoginStepTypeComplete: - fallthrough - default: - // This is probably impossible because of the above check that the next step type matches the request. - mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w) - } -} - func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { var params map[string]string err := json.NewDecoder(r.Body).Decode(¶ms) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Failed to decode request body", + ErrCode: mautrix.MNotJSON.ErrCode, + }) return } login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) @@ -501,48 +412,39 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http default: panic("Impossible state") } - if err == nil && nextStep == nil { - err = ErrNilStep - } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") - RespondWithError(w, err, "Internal error submitting input") - prov.deleteLogin(login, true) + respondMaybeCustomError(w, err, "Internal error submitting input") return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) - } else { - zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } - exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) - if err == nil && nextStep == nil { - err = ErrNilStep - } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") - RespondWithError(w, err, "Internal error waiting for login") - prov.deleteLogin(login, true) + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to wait", + ErrCode: "M_UNKNOWN", + }) return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) - } else { - zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } - exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - userLoginID := networkid.UserLoginID(r.PathValue("loginID")) + userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"]) if userLoginID == "all" { for { login := user.GetDefaultLogin() @@ -554,12 +456,15 @@ func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) } else { userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != user.MXID { - mautrix.MNotFound.WithMessage("Login not found").Write(w) + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "Login not found", + ErrCode: mautrix.MNotFound.ErrCode, + }) return } userLogin.Logout(r.Context()) } - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) + jsonResponse(w, http.StatusOK, json.RawMessage("{}")) } type RespGetLogins struct { @@ -568,7 +473,7 @@ type RespGetLogins struct { func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) + jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) } func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, bool) { @@ -578,21 +483,15 @@ func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r } userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID { - hlog.FromRequest(r).Warn(). - Str("login_id", string(userLoginID)). - Msg("Tried to use non-existent login, returning 404") - mautrix.MNotFound.WithMessage("Login not found").Write(w) + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "Login not found", + ErrCode: mautrix.MNotFound.ErrCode, + }) return nil, true } return userLogin, false } -var ErrNotLoggedIn = mautrix.RespError{ - Err: "Not logged in", - ErrCode: "FI.MAU.NOT_LOGGED_IN", - StatusCode: http.StatusBadRequest, -} - func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { userLogin, failed := prov.GetExplicitLoginForRequest(w, r) if userLogin != nil || failed { @@ -600,23 +499,40 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R } userLogin = prov.GetUser(r).GetDefaultLogin() if userLogin == nil { - ErrNotLoggedIn.Write(w) + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + }) return nil } return userLogin } -type WritableError interface { - Write(w http.ResponseWriter) +func respondMaybeCustomError(w http.ResponseWriter, err error, message string) { + var mautrixRespErr mautrix.RespError + var bv2RespErr bridgev2.RespError + if errors.As(err, &bv2RespErr) { + mautrixRespErr = mautrix.RespError(bv2RespErr) + } else if !errors.As(err, &mautrixRespErr) { + mautrixRespErr = mautrix.RespError{ + Err: message, + ErrCode: "M_UNKNOWN", + StatusCode: http.StatusInternalServerError, + } + } + if mautrixRespErr.StatusCode == 0 { + mautrixRespErr.StatusCode = http.StatusInternalServerError + } + jsonResponse(w, mautrixRespErr.StatusCode, mautrixRespErr) } -func RespondWithError(w http.ResponseWriter, err error, message string) { - var we WritableError - if errors.As(err, &we) { - we.Write(w) - } else { - mautrix.MUnknown.WithMessage(message).Write(w) - } +type RespResolveIdentifier struct { + ID networkid.UserID `json:"id"` + Name string `json:"name,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Identifiers []string `json:"identifiers,omitempty"` + MXID id.UserID `json:"mxid,omitempty"` + DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` } func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { @@ -624,18 +540,72 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. if login == nil { return } - resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat) - if err != nil { - RespondWithError(w, err, "Internal error resolving identifier") - } else if resp == nil { - mautrix.MNotFound.WithMessage("Identifier not found").Write(w) - } else { - status := http.StatusOK - if resp.JustCreated { - status = http.StatusCreated - } - exhttp.WriteJSONResponse(w, status, resp) + api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) + if !ok { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "This bridge does not support resolving identifiers", + ErrCode: mautrix.MUnrecognized.ErrCode, + }) + return } + resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") + respondMaybeCustomError(w, err, "Internal error resolving identifier") + return + } else if resp == nil { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "Identifier not found", + }) + return + } + apiResp := &RespResolveIdentifier{ + ID: resp.UserID, + } + status := http.StatusOK + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo) + } + apiResp.Name = resp.Ghost.Name + apiResp.AvatarURL = resp.Ghost.AvatarMXC + apiResp.Identifiers = resp.Ghost.Identifiers + apiResp.MXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + apiResp.Name = *resp.UserInfo.Name + } + if resp.Chat != nil { + if resp.Chat.Portal == nil { + resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to get portal", + ErrCode: "M_UNKNOWN", + }) + return + } + } + if createChat && resp.Chat.Portal.MXID == "" { + status = http.StatusCreated + err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to create portal room", + ErrCode: "M_UNKNOWN", + }) + return + } + } + apiResp.DMRoomID = resp.Chat.Portal.MXID + } + jsonResponse(w, status, resp) +} + +type RespGetContactList struct { + Contacts []*RespResolveIdentifier `json:"contacts"` } func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { @@ -643,36 +613,62 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque if login == nil { return } - resp, err := provisionutil.GetContactList(r.Context(), login) + api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) + if !ok { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "This bridge does not support listing contacts", + ErrCode: mautrix.MUnrecognized.ErrCode, + }) + return + } + resp, err := api.GetContactList(r.Context()) if err != nil { - RespondWithError(w, err, "Internal error getting contact list") + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") + respondMaybeCustomError(w, err, "Internal error fetching contact list") return } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) -} - -type ReqSearchUsers struct { - Query string `json:"query"` -} - -func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) { - var req ReqSearchUsers - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) - return + apiResp := &RespGetContactList{ + Contacts: make([]*RespResolveIdentifier, len(resp)), } - login := prov.GetLoginForRequest(w, r) - if login == nil { - return + for i, contact := range resp { + apiContact := &RespResolveIdentifier{ + ID: contact.UserID, + } + apiResp.Contacts[i] = apiContact + if contact.UserInfo != nil { + if contact.UserInfo.Name != nil { + apiContact.Name = *contact.UserInfo.Name + } + if contact.UserInfo.Identifiers != nil { + apiContact.Identifiers = contact.UserInfo.Identifiers + } + } + if contact.Ghost != nil { + if contact.Ghost.Name != "" { + apiContact.Name = contact.Ghost.Name + } + if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { + apiContact.Identifiers = contact.Ghost.Identifiers + } + apiContact.AvatarURL = contact.Ghost.AvatarMXC + apiContact.MXID = contact.Ghost.Intent.GetMXID() + } + if contact.Chat != nil { + if contact.Chat.Portal == nil { + contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), contact.Chat.PortalKey) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to get portal", + ErrCode: "M_UNKNOWN", + }) + return + } + } + apiContact.DMRoomID = contact.Chat.Portal.MXID + } } - resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query) - if err != nil { - RespondWithError(w, err, "Internal error searching users") - return - } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + jsonResponse(w, http.StatusOK, apiResp) } func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { @@ -684,114 +680,12 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request } func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) { - var req bridgev2.GroupCreateParams - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) - return - } - req.Type = r.PathValue("type") login := prov.GetLoginForRequest(w, r) if login == nil { return } - resp, err := provisionutil.CreateGroup(r.Context(), login, &req) - if err != nil { - RespondWithError(w, err, "Internal error creating group") - return - } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) -} - -type ReqExportCredentials struct { - RemoteID networkid.UserLoginID `json:"remote_id"` -} - -type RespExportCredentials struct { - Credentials any `json:"credentials"` -} - -func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *http.Request) { - prov.sessionTransfersLock.Lock() - defer prov.sessionTransfersLock.Unlock() - - var req ReqExportCredentials - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) - return - } - - user := prov.GetUser(r) - logins := user.GetUserLogins() - var loginToExport *bridgev2.UserLogin - for _, login := range logins { - if login.ID == req.RemoteID { - loginToExport = login - break - } - } - if loginToExport == nil { - mautrix.MNotFound.WithMessage("No matching user login found").Write(w) - return - } - - client, ok := loginToExport.Client.(bridgev2.CredentialExportingNetworkAPI) - if !ok { - mautrix.MUnrecognized.WithMessage("This bridge does not support exporting credentials").Write(w) - return - } - - if _, ok := prov.sessionTransfers[loginToExport.ID]; ok { - // Warn, but allow, double exports. This might happen if a client crashes handling creds, - // and should be safe to call multiple times. - zerolog.Ctx(r.Context()).Warn().Msg("Exporting already exported credentials") - } - - // Disconnect now so we don't use the same network session in two places at once - client.Disconnect() - exhttp.WriteJSONResponse(w, http.StatusOK, &RespExportCredentials{ - Credentials: client.ExportCredentials(r.Context()), + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "Creating groups is not yet implemented", + ErrCode: mautrix.MUnrecognized.ErrCode, }) } - -func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r *http.Request) { - prov.sessionTransfersLock.Lock() - defer prov.sessionTransfersLock.Unlock() - - var req ReqExportCredentials - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) - return - } - - user := prov.GetUser(r) - logins := user.GetUserLogins() - var loginToExport *bridgev2.UserLogin - for _, login := range logins { - if login.ID == req.RemoteID { - loginToExport = login - break - } - } - if loginToExport == nil { - mautrix.MNotFound.WithMessage("No matching user login found").Write(w) - return - } else if _, ok := prov.sessionTransfers[loginToExport.ID]; !ok { - mautrix.MBadState.WithMessage("No matching credential export found").Write(w) - return - } - - zerolog.Ctx(r.Context()).Info(). - Str("remote_name", string(req.RemoteID)). - Msg("Logging out remote after finishing credential export") - - loginToExport.Client.LogoutRemote(r.Context()) - delete(prov.sessionTransfers, req.RemoteID) - - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) -} diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index 26068db4..1daf7b07 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -91,24 +91,6 @@ paths: post: tags: [ auth ] summary: Start a new login process. - description: | - This endpoint starts a new login process, which is used to log into the bridge. - - The basic flow of the entire login, including calling this endpoint, is: - 1. Call `GET /v3/login/flows` to get the list of available flows. - If there's more than one flow, ask the user to pick which one they want to use. - 2. Call this endpoint with the chosen flow ID to start the login. - The first login step will be returned. - 3. Render the information provided in the step. - 4. Call the `/login/step/...` endpoint corresponding to the step type: - * For `user_input` and `cookies`, acquire the requested fields before calling the endpoint. - * For `display_and_wait`, call the endpoint immediately - (as there's nothing to acquire on the client side). - 5. Handle the data returned by the login step endpoint: - * If an error is returned, the login has failed and must be restarted - (from either step 1 or step 2) if the user wants to try again. - * If step type `complete` is returned, the login finished successfully. - * Otherwise, go to step 3 with the new data. operationId: startLogin parameters: - name: login_id @@ -258,51 +240,6 @@ paths: responses: 200: description: Contact list fetched successfully - content: - application/json: - schema: - type: object - properties: - contacts: - type: array - items: - $ref: '#/components/schemas/ResolvedIdentifier' - 401: - $ref: '#/components/responses/Unauthorized' - 404: - $ref: '#/components/responses/LoginNotFound' - 500: - $ref: '#/components/responses/InternalError' - 501: - $ref: '#/components/responses/NotSupported' - /v3/search_users: - post: - tags: [ snc ] - summary: Search for users on the remote network - operationId: searchUsers - parameters: - - $ref: "#/components/parameters/loginID" - requestBody: - content: - application/json: - schema: - type: object - properties: - query: - type: string - description: The search query to send to the remote network - responses: - 200: - description: Search completed successfully - content: - application/json: - schema: - type: object - properties: - results: - type: array - items: - $ref: '#/components/schemas/ResolvedIdentifier' 401: $ref: '#/components/responses/Unauthorized' 404: @@ -361,25 +298,14 @@ paths: $ref: '#/components/responses/InternalError' 501: $ref: '#/components/responses/NotSupported' - /v3/create_group/{type}: + /v3/create_group: post: tags: [ snc ] summary: Create a group chat on the remote network. operationId: createGroup parameters: - $ref: "#/components/parameters/loginID" - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/GroupCreateParams' responses: - 200: - description: Identifier resolved successfully - content: - application/json: - schema: - $ref: '#/components/schemas/CreatedGroup' 401: $ref: '#/components/responses/Unauthorized' 404: @@ -400,7 +326,7 @@ components: - username - meow@example.com loginID: - name: login_id + name: loginID in: query description: An optional explicit login ID to do the action through. required: false @@ -583,74 +509,6 @@ components: description: The Matrix room ID of the direct chat with the user. examples: - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' - GroupCreateParams: - type: object - description: | - Parameters for creating a group chat. - The /capabilities endpoint response must be checked to see which fields are actually allowed. - properties: - type: - type: string - description: The type of group to create. - examples: - - channel - username: - type: string - description: The public username for the created group. - participants: - type: array - description: The users to add to the group initially. - items: - type: string - parent: - type: object - name: - type: object - description: The `m.room.name` event content for the room. - properties: - name: - type: string - avatar: - type: object - description: The `m.room.avatar` event content for the room. - properties: - url: - type: string - format: mxc - topic: - type: object - description: The `m.room.topic` event content for the room. - properties: - topic: - type: string - disappear: - type: object - description: The `com.beeper.disappearing_timer` event content for the room. - properties: - type: - type: string - timer: - type: number - room_id: - type: string - format: matrix_room_id - description: | - An existing Matrix room ID to bridge to. - The other parameters must be already in sync with the room state when using this parameter. - CreatedGroup: - type: object - description: A successfully created group chat. - required: [id, mxid] - properties: - id: - type: string - description: The internal chat ID of the created group. - mxid: - type: string - format: matrix_room_id - description: The Matrix room ID of the portal. - examples: - - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' LoginStep: type: object description: A step in a login process. @@ -714,7 +572,7 @@ components: type: type: string description: The type of field. - enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ] + enum: [ username, phone_number, email, password, 2fa_code, token ] id: type: string description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. @@ -728,53 +586,10 @@ components: description: A more detailed description of the field shown to the user. examples: - Include the country code with a + - default_value: - type: string - description: A default value that the client can pre-fill the field with. pattern: type: string format: regex description: A regular expression that the field value must match. - options: - type: array - description: For fields of type select, the valid options. - items: - type: string - attachments: - type: array - description: A list of media attachments to show the user alongside the form fields. - items: - type: object - description: A media attachment to show the user. - required: [ type, filename, content ] - properties: - type: - type: string - description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported. - enum: [ m.image, m.audio ] - filename: - type: string - description: The filename for the media attachment. - content: - type: string - description: The raw file content for the attachment encoded in base64. - info: - type: object - description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted. - properties: - mimetype: - type: string - description: The MIME type for the media content. - examples: [ image/png, audio/mpeg ] - w: - type: number - description: The width of the media in pixels. Only applicable for images and videos. - h: - type: number - description: The height of the media in pixels. Only applicable for images and videos. - size: - type: number - description: The size of the media content in number of bytes. Strongly recommended to include. - description: Cookie login step required: [ type, cookies ] properties: @@ -793,20 +608,6 @@ components: user_agent: type: string description: An optional user agent that the webview should use. - wait_for_url_pattern: - type: string - description: | - A regex pattern that the URL should match before the client closes the webview. - - The client may submit the login if the user closes the webview after all cookies are collected - even if this URL is not reached, but it should only automatically close the webview after - both cookies and the URL match. - extract_js: - type: string - description: | - A JavaScript snippet that can extract some or all of the fields. - The snippet will evaluate to a promise that resolves when the relevant fields are found. - Fields that are not present in the promise result must be extracted another way. fields: type: array description: The list of cookies or other stored data that must be extracted. diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 82ea8c2b..9db5f442 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,26 +7,18 @@ package matrix import ( - "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/binary" "fmt" "io" - "mime" "net/http" - "net/url" - "slices" - "strings" "time" - "github.com/rs/zerolog" + "github.com/gorilla/mux" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -43,10 +35,7 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) - br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia) - br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia) - br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) - br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia) + br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet) return nil } @@ -57,20 +46,6 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte { return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)] } -func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte { - hasher := hmac.New(sha256.New, br.pubMediaSigKey) - hasher.Write([]byte(pm.MXC.String())) - hasher.Write([]byte(pm.MimeType)) - if pm.Keys != nil { - hasher.Write([]byte(pm.Keys.Version)) - hasher.Write([]byte(pm.Keys.Key.Algorithm)) - hasher.Write([]byte(pm.Keys.Key.Key)) - hasher.Write([]byte(pm.Keys.InitVector)) - hasher.Write([]byte(pm.Keys.Hashes.SHA256)) - } - return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength] -} - func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte { var expiresAt []byte if br.Config.PublicMedia.Expiry > 0 { @@ -101,15 +76,16 @@ var proxyHeadersToCopy = []string{ } func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) contentURI := id.ContentURI{ - Homeserver: r.PathValue("server"), - FileID: r.PathValue("mediaID"), + Homeserver: vars["server"], + FileID: vars["mediaID"], } if !contentURI.IsValid() { http.Error(w, "invalid content URI", http.StatusBadRequest) return } - checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum")) + checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"]) if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) { http.Error(w, "invalid base64 in checksum", http.StatusBadRequest) return @@ -120,47 +96,9 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { http.Error(w, "checksum expired", http.StatusGone) return } - br.doProxyMedia(w, r, contentURI, nil, "") -} - -func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) { - if !br.Config.PublicMedia.UseDatabase { - http.Error(w, "public media short links are disabled", http.StatusNotFound) - return - } - log := zerolog.Ctx(r.Context()) - media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID")) - if err != nil { - log.Err(err).Msg("Failed to get public media from database") - http.Error(w, "failed to get media metadata", http.StatusInternalServerError) - return - } else if media == nil { - http.Error(w, "media ID not found", http.StatusNotFound) - return - } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) { - // This is not gone as it can still be refreshed in the DB - http.Error(w, "media expired", http.StatusNotFound) - return - } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil { - http.Error(w, "media keys are malformed", http.StatusInternalServerError) - return - } - br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType) -} - -var safeMimes = []string{ - "text/css", "text/plain", "text/csv", - "application/json", "application/ld+json", - "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", - "video/mp4", "video/webm", "video/ogg", "video/quicktime", - "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", - "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac", -} - -func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) { resp, err := br.Bot.Download(r.Context(), contentURI) if err != nil { - zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") + br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") http.Error(w, "failed to download media", http.StatusInternalServerError) return } @@ -168,41 +106,11 @@ func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, conten for _, hdr := range proxyHeadersToCopy { w.Header()[hdr] = resp.Header[hdr] } - stream := resp.Body - if encInfo != nil { - if mimeType == "" { - mimeType = "application/octet-stream" - } - contentDisposition := "attachment" - if slices.Contains(safeMimes, mimeType) { - contentDisposition = "inline" - } - dispositionArgs := map[string]string{} - if filename := r.PathValue("filename"); filename != "" { - dispositionArgs["filename"] = filename - } - w.Header().Set("Content-Type", mimeType) - w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs)) - // Note: this won't check the Close result like it should, but it's probably not a big deal here - stream = encInfo.DecryptStream(stream) - } else if filename := r.PathValue("filename"); filename != "" { - contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) - if contentDisposition == "" { - contentDisposition = "attachment" - } - w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{ - "filename": filename, - })) - } w.WriteHeader(http.StatusOK) - _, _ = io.Copy(w, stream) + _, _ = io.Copy(w, resp.Body) } func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { - return br.getPublicMediaAddressWithFileName(contentURI, "") -} - -func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { if br.pubMediaSigKey == nil { return "" } @@ -210,69 +118,11 @@ func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIS if err != nil || !parsed.IsValid() { return "" } - fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_")) - if fileName == ".." { - fileName = "" - } - parts := []string{ + return fmt.Sprintf( + "%s/_mautrix/publicmedia/%s/%s/%s", br.GetPublicAddress(), - strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), parsed.Homeserver, parsed.FileID, base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), - fileName, - } - if fileName == "" { - parts = parts[:len(parts)-1] - } - return strings.Join(parts, "/") -} - -func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) { - if br.pubMediaSigKey == nil { - return "", bridgev2.ErrPublicMediaDisabled - } - if !br.Config.PublicMedia.UseDatabase { - if evt.File != nil { - return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled) - } - return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil - } - mxc := evt.URL - var keys *attachment.EncryptedFile - if evt.File != nil { - mxc = evt.File.URL - keys = &evt.File.EncryptedFile - } - parsedMXC, err := mxc.Parse() - if err != nil { - return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err) - } - pm := &database.PublicMedia{ - MXC: parsedMXC, - Keys: keys, - MimeType: evt.GetInfo().MimeType, - } - if br.Config.PublicMedia.Expiry > 0 { - pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second) - } - pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm)) - err = br.Bridge.DB.PublicMedia.Put(ctx, pm) - if err != nil { - return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err) - } - fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_")) - if fileName == ".." { - fileName = "" - } - parts := []string{ - br.GetPublicAddress(), - strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), - pm.PublicID, - fileName, - } - if fileName == "" { - parts = parts[:len(parts)-1] - } - return strings.Join(parts, "/"), nil + ) } diff --git a/bridgev2/matrix/websocket.go b/bridgev2/matrix/websocket.go index b498cacd..36b8bca4 100644 --- a/bridgev2/matrix/websocket.go +++ b/bridgev2/matrix/websocket.go @@ -57,11 +57,11 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { addr = br.Config.Homeserver.Address } for { - err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect) + err := br.AS.StartWebsocket(addr, onConnect) if errors.Is(err, appservice.ErrWebsocketManualStop) { return } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { - log.Warn().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") + log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") if br.OnWebsocketReplaced != nil { br.OnWebsocketReplaced() } else { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index be26db49..0628f16d 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,33 +8,26 @@ package bridgev2 import ( "context" - "fmt" - "io" - "net/http" - "os" "time" - "go.mau.fi/util/exhttp" + "github.com/gorilla/mux" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type MatrixCapabilities struct { - AutoJoinInvites bool - BatchSending bool - ArbitraryMemberChange bool - ExtraProfileMeta bool + AutoJoinInvites bool + BatchSending bool } type MatrixConnector interface { Init(*Bridge) Start(ctx context.Context) error - PreStop() Stop() GetCapabilities() *MatrixCapabilities @@ -54,85 +47,29 @@ type MatrixConnector interface { GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) - GenerateDeterministicRoomID(portalKey networkid.PortalKey) id.RoomID GenerateDeterministicEventID(roomID id.RoomID, portalKey networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID ServerName() string } -type MatrixConnectorWithArbitraryRoomState interface { - MatrixConnector - GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) -} - type MatrixConnectorWithServer interface { - MatrixConnector GetPublicAddress() string - GetRouter() *http.ServeMux -} - -type IProvisioningAPI interface { - GetRouter() *http.ServeMux - GetUser(r *http.Request) *User -} - -type MatrixConnectorWithProvisioning interface { - MatrixConnector - GetProvisioning() IProvisioningAPI + GetRouter() *mux.Router } type MatrixConnectorWithPublicMedia interface { - MatrixConnector GetPublicMediaAddress(contentURI id.ContentURIString) string - GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) } type MatrixConnectorWithNameDisambiguation interface { - MatrixConnector IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } -type MatrixConnectorWithBridgeIdentifier interface { - MatrixConnector - GetUniqueBridgeID() string -} - -type MatrixConnectorWithURLPreviews interface { - MatrixConnector - GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) -} - type MatrixConnectorWithPostRoomBridgeHandling interface { - MatrixConnector HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error } -type MatrixConnectorWithAnalytics interface { - MatrixConnector - TrackAnalytics(userID id.UserID, event string, properties map[string]any) -} - -type DirectNotificationData struct { - Portal *Portal - Sender *Ghost - MessageID networkid.MessageID - Message string - - FormattedNotification string - FormattedTitle string -} - -type MatrixConnectorWithNotifications interface { - MatrixConnector - DisplayNotification(ctx context.Context, data *DirectNotificationData) -} - -type MatrixConnectorWithHTTPSettings interface { - MatrixConnector - GetHTTPClientSettings() exhttp.ClientSettings -} - type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message @@ -141,48 +78,8 @@ type MatrixSendExtra struct { PartIndex int } -// FileStreamResult is the result of a FileStreamCallback. -type FileStreamResult struct { - // ReplacementFile is the path to a new file that replaces the original file provided to the callback. - // Providing a replacement file is only allowed if the requireFile flag was set for the UploadMediaStream call. - ReplacementFile string - // FileName is the name of the file to be specified when uploading to the server. - // This should be the same as the file name that will be included in the Matrix event (body or filename field). - // If the file gets encrypted, this field will be ignored. - FileName string - // MimeType is the type of field to be specified when uploading to the server. - // This should be the same as the mime type that will be included in the Matrix event (info -> mimetype field). - // If the file gets encrypted, this field will be replaced with application/octet-stream. - MimeType string -} - -// FileStreamCallback is a callback function for file uploads that roundtrip via disk. -// -// The parameter is either a file or an in-memory buffer depending on the size of the file and whether the requireFile flag was set. -// -// The return value must be non-nil unless there's an error, and should always include FileName and MimeType. -type FileStreamCallback func(file io.Writer) (*FileStreamResult, error) - -type CallbackError struct { - Type string - Wrapped error -} - -func (ce CallbackError) Error() string { - return fmt.Sprintf("%s callback failed: %s", ce.Type, ce.Wrapped.Error()) -} - -func (ce CallbackError) Unwrap() error { - return ce.Wrapped -} - -type EnsureJoinedParams struct { - Via []string -} - type MatrixAPI interface { GetMXID() id.UserID - IsDoublePuppet() bool SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *MatrixSendExtra) (*mautrix.RespSendEvent, error) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) @@ -190,9 +87,7 @@ type MatrixAPI interface { MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) - DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool, callback func(*os.File) error) error UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) - UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) SetDisplayName(ctx context.Context, name string) error SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error @@ -200,26 +95,14 @@ type MatrixAPI interface { CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error - EnsureJoined(ctx context.Context, roomID id.RoomID, params ...EnsureJoinedParams) error + InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error + EnsureJoined(ctx context.Context, roomID id.RoomID) error EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error - - GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) -} - -type StreamOrderReadingMatrixAPI interface { - MatrixAPI - MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error } type MarkAsDMMatrixAPI interface { - MatrixAPI MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error } - -type EphemeralSendingMatrixAPI interface { - MatrixAPI - BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) -} diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 75c00cb0..740743f6 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -19,17 +19,17 @@ import ( "maunium.net/go/mautrix/id" ) -func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { +func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) { log := zerolog.Ctx(ctx) // These invites should already be rejected in QueueMatrixEvent if !sender.Permissions.Commands { log.Warn().Msg("Received bot invite from user without permission to send commands") - return EventHandlingResultIgnored + return } err := br.Bot.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to accept invite to room") - return EventHandlingResultFailed + return } log.Debug().Msg("Accepted invite to room as bot") members, err := br.Matrix.GetMembers(ctx, evt.RoomID) @@ -55,7 +55,6 @@ func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender log.Err(err).Msg("Failed to send welcome message to room") } } - return EventHandlingResultSuccess } func sendNotice(ctx context.Context, evt *event.Event, intent MatrixAPI, message string, args ...any) { @@ -88,69 +87,39 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, rejectInvite(ctx, evt, intent, "") } -func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) { - if portal.MXID == "" { - return - } - log := zerolog.Ctx(ctx) - existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) - if err != nil { - log.Err(err). - Stringer("old_portal_mxid", portal.MXID). - Msg("Failed to check existing portal members, deleting room") - } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Msg("Inviter has no member event in old portal, deleting room") - } else if targetUserMember.Membership.IsInviteOrJoin() { - return - } else { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Str("membership", string(targetUserMember.Membership)). - Msg("Inviter is not in old portal, deleting room") - } - - if err = portal.RemoveMXID(ctx); err != nil { - log.Err(err).Msg("Failed to delete old portal mxid") - } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { - log.Err(err).Msg("Failed to clean up old portal room") - } -} - -func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { +func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) { ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) validator, ok := br.Network.(IdentifierValidatingNetwork) if ghostID == "" || (ok && !validator.ValidateUserID(ghostID)) { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "Malformed user ID") - return EventHandlingResultIgnored + return } log := zerolog.Ctx(ctx).With(). Str("invitee_network_id", string(ghostID)). Stringer("room_id", evt.RoomID). Logger() // TODO sort in preference order - logins := sender.GetUserLogins() + logins := sender.GetCachedUserLogins() if len(logins) == 0 { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "You're not logged in") - return EventHandlingResultIgnored + return } _, ok = logins[0].Client.(IdentifierResolvingNetworkAPI) if !ok { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "This bridge does not support starting chats") - return EventHandlingResultIgnored + return } invitedGhost, err := br.GetGhostByID(ctx, ghostID) if err != nil { log.Err(err).Msg("Failed to get invited ghost") - return EventHandlingResultFailed + return } err = invitedGhost.Intent.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to accept invite to room") - return EventHandlingResultFailed + return } - var resp *CreateChatResponse + var resp *ResolveIdentifierResponse var sourceLogin *UserLogin // TODO this should somehow lock incoming event processing to avoid race conditions where a new portal room is created // between ResolveIdentifier returning and the portal MXID being updated. @@ -159,23 +128,14 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if !ok { continue } - var resolveResp *ResolveIdentifierResponse - ghostAPI, ok := login.Client.(GhostDMCreatingNetworkAPI) - if ok { - resp, err = ghostAPI.CreateChatWithGhost(ctx, invitedGhost) - } else { - resolveResp, err = api.ResolveIdentifier(ctx, string(ghostID), true) - if resolveResp != nil { - resp = resolveResp.Chat - } - } + resp, err = api.ResolveIdentifier(ctx, string(ghostID), true) if errors.Is(err, ErrResolveIdentifierTryNext) { log.Debug().Err(err).Str("login_id", string(login.ID)).Msg("Failed to resolve identifier, trying next login") continue } else if err != nil { log.Err(err).Msg("Failed to resolve identifier") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to create chat") - return EventHandlingResultFailed + return } else { sourceLogin = login break @@ -184,93 +144,69 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if resp == nil { log.Warn().Msg("No login could resolve the identifier") sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create chat via any login") - return EventHandlingResultFailed + return } - portal := resp.Portal + portal := resp.Chat.Portal if portal == nil { - portal, err = br.GetPortalByKey(ctx, resp.PortalKey) + portal, err = br.GetPortalByKey(ctx, resp.Chat.PortalKey) if err != nil { log.Err(err).Msg("Failed to get portal by key") sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create portal entry") - return EventHandlingResultFailed + return } } - portal.CleanupOrphanedDM(ctx, sender.MXID) err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) if err != nil { log.Err(err).Msg("Failed to ensure bot is invited to room") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to invite bridge bot") - return EventHandlingResultFailed + return } err = br.Bot.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to ensure bot is joined to room") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to join with bridge bot") - return EventHandlingResultFailed + return } - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - portalMXID := portal.MXID - if portalMXID != "" { - sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL()) - rejectInvite(ctx, evt, br.Bot, "") - return EventHandlingResultSuccess + didSetPortal := portal.setMXIDToExistingRoom(evt.RoomID) + if resp.Chat.PortalInfo != nil { + portal.UpdateInfo(ctx, resp.Chat.PortalInfo, sourceLogin, nil, time.Time{}) } - err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) - if err != nil { - log.Err(err).Msg("Failed to give permissions to bridge bot") - sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot") - rejectInvite(ctx, evt, br.Bot, "") - return EventHandlingResultSuccess - } - overrideIntent := invitedGhost.Intent - if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { - log.Debug(). - Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). - Msg("Created DM was redirected to another user ID") - _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ - Parsed: &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: "Direct chat redirected to another internal user ID", + if didSetPortal { + // TODO this might become unnecessary if UpdateInfo starts taking care of it + _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{br.Bot.GetMXID()}, }, }, time.Time{}) if err != nil { - log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") + log.Warn().Err(err).Msg("Failed to set service members in room") } - if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot { - overrideIntent = br.Bot - } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil { - log.Err(err).Msg("Failed to get ghost of real portal other user ID") - } else { - invitedGhost = otherUserGhost - overrideIntent = otherUserGhost.Intent - } - } - err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{ - // We locked it before checking the mxid - RoomCreateAlreadyLocked: true, - - FailIfMXIDSet: true, - ChatInfo: resp.PortalInfo, - ChatInfoSource: sourceLogin, - }) - if err != nil { - log.Err(err).Msg("Failed to update Matrix room ID for new DM portal") - sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work") - return EventHandlingResultSuccess - } - message := "Private chat portal created" - mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) - if ok { - err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) + message := "Private chat portal created" + err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) + hasWarning := false if err != nil { - log.Err(err).Msg("Error in connector newly bridged room handler") - message += fmt.Sprintf("\n\nWarning: %s", err.Error()) + log.Warn().Err(err).Msg("Failed to give power to bot in new DM") + message += "\n\nWarning: failed to promote bot" + hasWarning = true } + mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) + if ok { + err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) + if err != nil { + if hasWarning { + message += fmt.Sprintf(", %s", err.Error()) + } else { + message += fmt.Sprintf("\n\nWarning: %s", err.Error()) + } + } + } + sendNotice(ctx, evt, invitedGhost.Intent, message) + } else { + // TODO ensure user is invited even if PortalInfo wasn't provided? + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL()) + rejectInvite(ctx, evt, br.Bot, "") } - sendNotice(ctx, evt, overrideIntent, message) - return EventHandlingResultSuccess } func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWithPower MatrixAPI) error { @@ -280,9 +216,6 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith } userLevel := powers.GetUserLevel(userWithPower.GetMXID()) if powers.EnsureUserLevelAs(userWithPower.GetMXID(), br.Bot.GetMXID(), userLevel) { - if userLevel > powers.UsersDefault { - powers.SetUserLevel(userWithPower.GetMXID(), userLevel-1) - } _, err = userWithPower.SendState(ctx, roomID, event.StatePowerLevels, "", &event.Content{ Parsed: powers, }, time.Time{}) @@ -292,3 +225,17 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith } return nil } + +func (portal *Portal) setMXIDToExistingRoom(roomID id.RoomID) bool { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + if portal.MXID != "" { + return false + } + portal.MXID = roomID + portal.updateLogger() + portal.Bridge.cacheLock.Lock() + portal.Bridge.portalsByMXID[portal.MXID] = portal + portal.Bridge.cacheLock.Unlock() + return true +} diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index df0c9e4d..04ee8eca 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -12,24 +12,18 @@ import ( "go.mau.fi/util/jsontime" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type MessageStatusEventInfo struct { - RoomID id.RoomID - TransactionID string - SourceEventID id.EventID - NewEventID id.EventID - EventType event.Type - MessageType event.MessageType - Sender id.UserID - ThreadRoot id.EventID - StreamOrder int64 - - IsSourceEventDoublePuppeted bool + RoomID id.RoomID + EventID id.EventID + EventType event.Type + MessageType event.MessageType + Sender id.UserID + ThreadRoot id.EventID } func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { @@ -37,19 +31,13 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { if relatable, ok := evt.Content.Parsed.(event.Relatable); ok { threadRoot = relatable.OptionalGetRelatesTo().GetThreadParent() } - - _, isDoublePuppeted := evt.Content.Raw[appservice.DoublePuppetKey] - return &MessageStatusEventInfo{ - RoomID: evt.RoomID, - TransactionID: evt.Unsigned.TransactionID, - SourceEventID: evt.ID, - EventType: evt.Type, - MessageType: evt.Content.AsMessage().MsgType, - Sender: evt.Sender, - ThreadRoot: threadRoot, - - IsSourceEventDoublePuppeted: isDoublePuppeted, + RoomID: evt.RoomID, + EventID: evt.ID, + EventType: evt.Type, + MessageType: evt.Content.AsMessage().MsgType, + Sender: evt.Sender, + ThreadRoot: threadRoot, } } @@ -161,7 +149,7 @@ func (ms *MessageStatus) ToCheckpoint(evt *MessageStatusEventInfo) *status.Messa } checkpoint := &status.MessageCheckpoint{ RoomID: evt.RoomID, - EventID: evt.SourceEventID, + EventID: evt.EventID, Step: step, Timestamp: jsontime.UnixMilliNow(), Status: ms.checkpointStatus(), @@ -182,12 +170,11 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe content := &event.BeeperMessageStatusEventContent{ RelatesTo: event.RelatesTo{ Type: event.RelReference, - EventID: evt.SourceEventID, + EventID: evt.EventID, }, - TargetTxnID: evt.TransactionID, - Status: ms.Status, - Reason: ms.ErrorReason, - Message: ms.Message, + Status: ms.Status, + Reason: ms.ErrorReason, + Message: ms.Message, } if ms.InternalError != nil { content.InternalError = ms.InternalError.Error() @@ -222,15 +209,15 @@ func (ms *MessageStatus) ToNoticeEvent(evt *MessageStatusEventInfo) *event.Messa messagePrefix = "Handling your command panicked" } content := &event.MessageEventContent{ - MsgType: event.MsgNotice, + MsgType: event.MsgText, Body: fmt.Sprintf("\u26a0\ufe0f %s: %s", messagePrefix, msg), RelatesTo: &event.RelatesTo{}, Mentions: &event.Mentions{}, } if evt.ThreadRoot != "" { - content.RelatesTo.SetThread(evt.ThreadRoot, evt.SourceEventID) + content.RelatesTo.SetThread(evt.ThreadRoot, evt.EventID) } else { - content.RelatesTo.SetReplyTo(evt.SourceEventID) + content.RelatesTo.SetReplyTo(evt.EventID) } if evt.Sender != "" { content.Mentions.UserIDs = []id.UserID{evt.Sender} diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index e3a6df70..46f82155 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -43,12 +43,9 @@ type PortalID string // It is also permitted to use a non-empty receiver for group chats if there is a good reason to // segregate them. For example, Telegram's non-supergroups have user-scoped message IDs instead // of chat-scoped IDs, which is easier to manage with segregated rooms. -// -// As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true. -// The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user. type PortalKey struct { - ID PortalID `json:"portal_id"` - Receiver UserLoginID `json:"portal_receiver,omitempty"` + ID PortalID + Receiver UserLoginID } func (pk PortalKey) IsEmpty() bool { @@ -94,11 +91,6 @@ type MessageID string // Transaction IDs must be unique across users in a room, but don't need to be unique across different rooms. type TransactionID string -// RawTransactionID is a client-generated identifier for a message send operation on the remote network. -// -// Unlike TransactionID, RawTransactionID's are only used for sending and don't have any uniqueness requirements. -type RawTransactionID string - // PartID is the ID of a message part on the remote network (e.g. index of image in album). // // Part IDs are only unique within a message, not globally. diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b706aedb..b68ad0c9 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -8,7 +8,6 @@ package bridgev2 import ( "context" - "encoding/json" "fmt" "strings" "time" @@ -16,9 +15,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" - "go.mau.fi/util/random" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -79,28 +76,8 @@ type EventSender struct { ForceDMUser bool } -func (es EventSender) MarshalZerologObject(evt *zerolog.Event) { - evt.Str("user_id", string(es.Sender)) - if string(es.SenderLogin) != string(es.Sender) { - evt.Str("sender_login", string(es.SenderLogin)) - } - if es.IsFromMe { - evt.Bool("is_from_me", true) - } - if es.ForceDMUser { - evt.Bool("force_dm_user", true) - } -} - type ConvertedMessage struct { - ReplyTo *networkid.MessageOptionalPartID - // Optional additional info about the reply. This is only used when backfilling messages - // on Beeper, where replies may target messages that haven't been bridged yet. - // Standard Matrix servers can't backwards backfill, so these are never used. - ReplyToRoom networkid.PortalKey - ReplyToUser networkid.UserID - ReplyToLogin networkid.UserLoginID - + ReplyTo *networkid.MessageOptionalPartID ThreadRoot *networkid.MessageID Parts []*ConvertedMessagePart Disappear database.DisappearingSetting @@ -119,15 +96,11 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa mediaPart.Content.EnsureHasHTML() mediaPart.Content.Body += "\n\n" + textPart.Content.Body mediaPart.Content.FormattedBody += "

" + textPart.Content.FormattedBody - mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions) - mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...) } else { mediaPart.Content.FileName = mediaPart.Content.Body mediaPart.Content.Body = textPart.Content.Body mediaPart.Content.Format = textPart.Content.Format mediaPart.Content.FormattedBody = textPart.Content.FormattedBody - mediaPart.Content.Mentions = textPart.Content.Mentions - mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews } if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok { metaMerger.CopyFrom(textPart.DBMetadata) @@ -254,16 +227,6 @@ type NetworkConnector interface { // This should generally not do any work, it should just return a LoginProcess that remembers // the user and will execute the requested flow. The actual work should start when [LoginProcess.Start] is called. CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) - - // GetBridgeInfoVersion returns version numbers for bridge info and room capabilities respectively. - // When the versions change, the bridge will automatically resend bridge info to all rooms. - GetBridgeInfoVersion() (info, capabilities int) -} - -type StoppableNetwork interface { - NetworkConnector - // Stop is called when the bridge is stopping, after all network clients have been disconnected. - Stop() } // DirectMediableNetwork is an optional interface that network connectors can implement to support direct media access. @@ -275,22 +238,14 @@ type StoppableNetwork interface { type DirectMediableNetwork interface { NetworkConnector SetUseDirectMedia() - Download(ctx context.Context, mediaID networkid.MediaID, params map[string]string) (mediaproxy.GetMediaResponse, error) + Download(ctx context.Context, mediaID networkid.MediaID) (mediaproxy.GetMediaResponse, error) } -// IdentifierValidatingNetwork is an optional interface that network connectors can implement to validate the shape of user IDs. -// -// This should not perform any checks to see if the user ID actually exists on the network, just that the user ID looks valid. type IdentifierValidatingNetwork interface { NetworkConnector ValidateUserID(id networkid.UserID) bool } -type TransactionIDGeneratingNetwork interface { - NetworkConnector - GenerateTransactionID(userID id.UserID, roomID id.RoomID, eventType event.Type) networkid.RawTransactionID -} - type PortalBridgeInfoFillingNetwork interface { NetworkConnector FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent) @@ -318,38 +273,16 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } -type NetworkResettingNetwork interface { - NetworkConnector - // ResetHTTPTransport should recreate the HTTP client used by the bridge. - // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable. - ResetHTTPTransport() - // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections. - // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here. - ResetNetworkConnections() -} - -type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) - type MatrixMessageResponse struct { - DB *database.Message - StreamOrder int64 - // If Pending is set, the bridge will not save the provided message to the database. - // This should only be used if AddPendingToSave has been called. - Pending bool - // If RemovePending is set, the bridge will remove the provided transaction ID from pending messages - // after saving the provided message to the database. This should be used with AddPendingToIgnore. - RemovePending networkid.TransactionID - // An optional function that is called after the message is saved to the database. - // Will not be called if the message is not saved for some reason. - PostSave func(context.Context, *database.Message) + DB *database.Message + + Pending networkid.TransactionID + HandleEcho func(RemoteMessage, *database.Message) (bool, error) } -type OutgoingTimeoutConfig struct { - CheckInterval time.Duration - NoEchoTimeout time.Duration - NoEchoMessage string - NoAckTimeout time.Duration - NoAckMessage string +type FileRestriction struct { + MaxSize int64 + MimeTypes []string } type NetworkGeneralCapabilities struct { @@ -359,16 +292,34 @@ type NetworkGeneralCapabilities struct { // Should the bridge re-request user info on incoming messages even if the ghost already has info? // By default, info is only requested for ghosts with no name, and other updating is left to events. AggressiveUpdateInfo bool - // Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message? - // This should be enabled if the network requires each message to be marked as read independently, - // and doesn't automatically do it when sending a message. - ImplicitReadReceipts bool - // If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave]) - // to handle asynchronous message responses, this field can be set to enable - // automatic timeout errors in case the asynchronous response never arrives. - OutgoingMessageTimeouts *OutgoingTimeoutConfig - // Capabilities related to the provisioning API. - Provisioning ProvisioningCapabilities +} + +type NetworkRoomCapabilities struct { + FormattedText bool + UserMentions bool + RoomMentions bool + + LocationMessages bool + Captions bool + MaxTextLength int + MaxCaptionLength int + + Threads bool + Replies bool + Edits bool + EditMaxCount int + EditMaxAge time.Duration + Deletes bool + DeleteMaxAge time.Duration + + DefaultFileRestriction *FileRestriction + Files map[event.MessageType]FileRestriction + + ReadReceipts bool + + Reactions bool + ReactionCount int + AllowedReactions []string } // NetworkAPI is an interface representing a remote network client for a single user login. @@ -378,9 +329,7 @@ type NetworkGeneralCapabilities struct { type NetworkAPI interface { // Connect is called to actually connect to the remote network. // If there's no persistent connection, this may just check access token validity, or even do nothing at all. - // This method isn't allowed to return errors, because any connection errors should be sent - // using the bridge state mechanism (UserLogin.BridgeState.Send) - Connect(ctx context.Context) + Connect(ctx context.Context) error // Disconnect should disconnect from the remote network. // A clean disconnection is preferred, but it should not take too long. Disconnect() @@ -403,7 +352,7 @@ type NetworkAPI interface { // GetCapabilities returns the bridging capabilities in a given room. // This can simply return a static list if the remote network has no per-chat capability differences, // but all calls will include the portal, because some networks do have per-chat differences. - GetCapabilities(ctx context.Context, portal *Portal) *event.RoomFeatures + GetCapabilities(ctx context.Context, portal *Portal) *NetworkRoomCapabilities // HandleMatrixMessage is called when a message is sent from Matrix in an existing portal room. // This function should convert the message as appropriate, send it over to the remote network, @@ -413,30 +362,6 @@ type NetworkAPI interface { HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) } -type ConnectBackgroundParams struct { - // RawData is the raw data in the push that triggered the background connection. - RawData json.RawMessage - // ExtraData is the data returned by [PushParsingNetwork.ParsePushNotification]. - // It's only present for native pushes. Relayed pushes will only have the raw data. - ExtraData any -} - -// BackgroundSyncingNetworkAPI is an optional interface that network connectors can implement to support background resyncs. -type BackgroundSyncingNetworkAPI interface { - NetworkAPI - // ConnectBackground is called in place of Connect for background resyncs. - // The client should connect to the remote network, handle pending messages, and then disconnect. - // This call should block until the entire sync is complete and the client is disconnected. - ConnectBackground(ctx context.Context, params *ConnectBackgroundParams) error -} - -// CredentialExportingNetworkAPI is an optional interface that networks connectors can implement to support export of -// the credentials associated with that login. Credential type is bridge specific. -type CredentialExportingNetworkAPI interface { - NetworkAPI - ExportCredentials(ctx context.Context) any -} - // FetchMessagesParams contains the parameters for a message history pagination request. type FetchMessagesParams struct { // The portal to fetch messages in. Always present. @@ -499,43 +424,6 @@ type BackfillMessage struct { LastThreadMessage networkid.MessageID } -var ( - _ RemoteMessageWithTransactionID = (*BackfillMessage)(nil) - _ RemoteEventWithTimestamp = (*BackfillMessage)(nil) -) - -func (b *BackfillMessage) GetType() RemoteEventType { - return RemoteEventMessage -} - -func (b *BackfillMessage) GetPortalKey() networkid.PortalKey { - panic("GetPortalKey called for BackfillMessage") -} - -func (b *BackfillMessage) AddLogContext(c zerolog.Context) zerolog.Context { - return c -} - -func (b *BackfillMessage) GetSender() EventSender { - return b.Sender -} - -func (b *BackfillMessage) GetID() networkid.MessageID { - return b.ID -} - -func (b *BackfillMessage) GetTransactionID() networkid.TransactionID { - return b.TxnID -} - -func (b *BackfillMessage) GetTimestamp() time.Time { - return b.Timestamp -} - -func (b *BackfillMessage) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { - return b.ConvertedMessage, nil -} - // FetchMessagesResponse contains the response for a message history pagination request. type FetchMessagesResponse struct { // The messages to backfill. Messages should always be sorted in chronological order (oldest to newest). @@ -552,11 +440,6 @@ type FetchMessagesResponse struct { // to mark the messages as read immediately after backfilling. MarkRead bool - // Should the bridge check each message against the database to ensure it's not a duplicate before bridging? - // By default, the bridge will only drop messages that are older than the last bridged message for forward backfills, - // or newer than the first for backward. - AggressiveDeduplication bool - // When HasMore is true, one of the following fields can be set to report backfill progress: // Approximate backfill progress as a number between 0 and 1. @@ -565,26 +448,16 @@ type FetchMessagesResponse struct { ApproxRemainingCount int // Approximate total number of messages in the chat. ApproxTotalCount int - - // An optional function that is called after the backfill batch has been sent. - CompleteCallback func() } // BackfillingNetworkAPI is an optional interface that network connectors can implement to support backfilling message history. type BackfillingNetworkAPI interface { NetworkAPI - // FetchMessages returns a batch of messages to backfill in a portal room. - // For details on the input and output, see the documentation of [FetchMessagesParams] and [FetchMessagesResponse]. FetchMessages(ctx context.Context, fetchParams FetchMessagesParams) (*FetchMessagesResponse, error) } -// BackfillingNetworkAPIWithLimits is an optional interface that network connectors can implement to customize -// the limit for backwards backfilling tasks. It is recommended to implement this by reading the MaxBatchesOverride -// config field with network-specific keys for different room types. type BackfillingNetworkAPIWithLimits interface { BackfillingNetworkAPI - // GetBackfillMaxBatchCount is called before a backfill task is executed to determine the maximum number of batches - // that should be backfilled. Return values less than 0 are treated as unlimited. GetBackfillMaxBatchCount(ctx context.Context, portal *Portal, task *database.BackfillTask) int } @@ -597,12 +470,6 @@ type EditHandlingNetworkAPI interface { HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error } -type PollHandlingNetworkAPI interface { - NetworkAPI - HandleMatrixPollStart(ctx context.Context, msg *MatrixPollStart) (*MatrixMessageResponse, error) - HandleMatrixPollVote(ctx context.Context, msg *MatrixPollVote) (*MatrixMessageResponse, error) -} - // ReactionHandlingNetworkAPI is an optional interface that network connectors can implement to handle message reactions. type ReactionHandlingNetworkAPI interface { NetworkAPI @@ -638,16 +505,6 @@ type ReadReceiptHandlingNetworkAPI interface { HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error } -// ChatViewingNetworkAPI is an optional interface that network connectors can implement to handle viewing chat status. -type ChatViewingNetworkAPI interface { - NetworkAPI - // HandleMatrixViewingChat is called when the user opens a portal room. - // This will never be called by the standard appservice connector, - // as Matrix doesn't have any standard way of signaling chat open status. - // Clients are expected to call this every 5 seconds. There is no signal for closing a chat. - HandleMatrixViewingChat(ctx context.Context, msg *MatrixViewingChat) error -} - // TypingHandlingNetworkAPI is an optional interface that network connectors can implement to handle typing events. type TypingHandlingNetworkAPI interface { NetworkAPI @@ -702,35 +559,6 @@ type RoomTopicHandlingNetworkAPI interface { HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error) } -type DisappearTimerChangingNetworkAPI interface { - NetworkAPI - // HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed. - // This method should update the Disappear field of the Portal with the new timer and return true - // if the change was successful. If the change is not successful, then the field should not be updated. - HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error) -} - -// DeleteChatHandlingNetworkAPI is an optional interface that network connectors -// can implement to delete a chat from the remote network. -type DeleteChatHandlingNetworkAPI interface { - NetworkAPI - // HandleMatrixDeleteChat is called when the user explicitly deletes a chat. - HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error -} - -// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors -// can implement to accept message requests from the remote network. -type MessageRequestAcceptingNetworkAPI interface { - NetworkAPI - // HandleMatrixAcceptMessageRequest is called when the user accepts a message request. - HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error -} - -type BeeperAIStreamHandlingNetworkAPI interface { - NetworkAPI - HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error -} - type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -750,27 +578,11 @@ type ResolveIdentifierResponse struct { Chat *CreateChatResponse } -var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10)) - type CreateChatResponse struct { PortalKey networkid.PortalKey // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. Portal *Portal PortalInfo *ChatInfo - // If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user, - // this field should have the user ID of said different user. - DMRedirectedTo networkid.UserID - - FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant -} - -type CreateChatFailedParticipant struct { - Reason string `json:"reason"` - InviteEventType string `json:"invite_event_type,omitempty"` - InviteContent *event.Content `json:"invite_content,omitempty"` - - UserMXID id.UserID `json:"user_mxid,omitempty"` - DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"` } // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. @@ -782,16 +594,6 @@ type IdentifierResolvingNetworkAPI interface { ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*ResolveIdentifierResponse, error) } -// GhostDMCreatingNetworkAPI is an optional extension to IdentifierResolvingNetworkAPI for starting chats with pre-validated user IDs. -type GhostDMCreatingNetworkAPI interface { - IdentifierResolvingNetworkAPI - // CreateChatWithGhost may be called instead of [IdentifierResolvingNetworkAPI.ResolveIdentifier] - // when starting a chat with an internal user identifier that has been pre-validated using - // [IdentifierValidatingNetwork.ValidateUserID]. If this is not implemented, ResolveIdentifier - // will be used instead (by stringifying the ghost ID). - CreateChatWithGhost(ctx context.Context, ghost *Ghost) (*CreateChatResponse, error) -} - // ContactListingNetworkAPI is an optional interface that network connectors can implement to provide the user's contact list. type ContactListingNetworkAPI interface { NetworkAPI @@ -805,83 +607,7 @@ type UserSearchingNetworkAPI interface { type GroupCreatingNetworkAPI interface { IdentifierResolvingNetworkAPI - CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) -} - -type PersonalFilteringCustomizingNetworkAPI interface { - NetworkAPI - CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) -} - -type ProvisioningCapabilities struct { - ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"` - GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"` -} - -type ResolveIdentifierCapabilities struct { - // Can DMs be created after resolving an identifier? - CreateDM bool `json:"create_dm"` - // Can users be looked up by phone number? - LookupPhone bool `json:"lookup_phone"` - // Can users be looked up by email address? - LookupEmail bool `json:"lookup_email"` - // Can users be looked up by network-specific username? - LookupUsername bool `json:"lookup_username"` - // Can any phone number be contacted without having to validate it via lookup first? - AnyPhone bool `json:"any_phone"` - // Can a contact list be retrieved from the bridge? - ContactList bool `json:"contact_list"` - // Can users be searched by name on the remote network? - Search bool `json:"search"` -} - -type GroupTypeCapabilities struct { - TypeDescription string `json:"type_description"` - - Name GroupFieldCapability `json:"name"` - Username GroupFieldCapability `json:"username"` - Avatar GroupFieldCapability `json:"avatar"` - Topic GroupFieldCapability `json:"topic"` - Disappear GroupFieldCapability `json:"disappear"` - Participants GroupFieldCapability `json:"participants"` - Parent GroupFieldCapability `json:"parent"` -} - -type GroupFieldCapability struct { - // Is setting this field allowed at all in the create request? - // Even if false, the network connector should attempt to set the metadata after group creation, - // as the allowed flag can't be enforced properly when creating a group for an existing Matrix room. - Allowed bool `json:"allowed"` - // Is setting this field mandatory for the creation to succeed? - Required bool `json:"required,omitempty"` - // The minimum/maximum length of the field, if applicable. - // For members, length means the number of members excluding the creator. - MinLength int `json:"min_length,omitempty"` - MaxLength int `json:"max_length,omitempty"` - - // Only for the disappear field: allowed disappearing settings - DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"` - - // This can be used to tell provisionutil not to call ValidateUserID on each participant. - // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs. - SkipIdentifierValidation bool `json:"-"` -} - -type GroupCreateParams struct { - Type string `json:"type,omitempty"` - - Username string `json:"username,omitempty"` - // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs - Participants []networkid.UserID `json:"participants,omitempty"` - Parent *networkid.PortalKey `json:"parent,omitempty"` - - Name *event.RoomNameEventContent `json:"name,omitempty"` - Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"` - Topic *event.TopicEventContent `json:"topic,omitempty"` - Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"` - - // An existing room ID to bridge to. If unset, a new room will be created. - RoomID id.RoomID `json:"room_id,omitempty"` + CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error) } type MembershipChangeType struct { @@ -921,15 +647,16 @@ type MatrixMembershipChange struct { MatrixRoomMeta[*event.MemberEventContent] Target GhostOrUserLogin Type MembershipChangeType -} -type MatrixMembershipResult struct { - RedirectTo networkid.UserID + // Deprecated: Use Target instead + TargetGhost *Ghost + // Deprecated: Use Target instead + TargetUserLogin *UserLogin } type MembershipHandlingNetworkAPI interface { NetworkAPI - HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error) + HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error) } type SinglePowerLevelChange struct { @@ -1018,32 +745,13 @@ type PushConfig struct { Web *WebPushConfig `json:"web,omitempty"` FCM *FCMPushConfig `json:"fcm,omitempty"` APNs *APNsPushConfig `json:"apns,omitempty"` - // If Native is true, it means the network supports registering for pushes - // that are delivered directly to the app without the use of a push relay. - Native bool `json:"native,omitempty"` } -// PushableNetworkAPI is an optional interface that network connectors can implement -// to support waking up the wrapper app using push notifications. type PushableNetworkAPI interface { - NetworkAPI - - // RegisterPushNotifications is called when the wrapper app wants to register a push token with the remote network. RegisterPushNotifications(ctx context.Context, pushType PushType, token string) error - // GetPushConfigs is used to find which types of push notifications the remote network can provide. GetPushConfigs() *PushConfig } -// PushParsingNetwork is an optional interface that network connectors can implement -// to support parsing native push notifications from networks. -type PushParsingNetwork interface { - NetworkConnector - - // ParsePushNotification is called when a native push is received. - // It must return the corresponding user login ID to wake up, plus optionally data to pass to the wakeup call. - ParsePushNotification(ctx context.Context, data json.RawMessage) (networkid.UserLoginID, any, error) -} - type RemoteEventType int func (ret RemoteEventType) String() string { @@ -1115,11 +823,6 @@ type RemoteEvent interface { GetSender() EventSender } -type RemoteEventWithContextMutation interface { - RemoteEvent - MutateContext(ctx context.Context) context.Context -} - type RemoteEventWithUncertainPortalReceiver interface { RemoteEvent PortalReceiverIsUncertain() bool @@ -1130,11 +833,6 @@ type RemotePreHandler interface { PreHandle(ctx context.Context, portal *Portal) } -type RemotePostHandler interface { - RemoteEvent - PostHandle(ctx context.Context, portal *Portal) -} - type RemoteChatInfoChange interface { RemoteEvent GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) @@ -1164,20 +862,11 @@ type RemoteBackfill interface { GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) } -type RemoteDeleteOnlyForMe interface { +type RemoteChatDelete interface { RemoteEvent DeleteOnlyForMe() bool } -type RemoteChatDelete interface { - RemoteDeleteOnlyForMe -} - -type RemoteChatDeleteWithChildren interface { - RemoteChatDelete - DeleteChildren() bool -} - type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool @@ -1297,11 +986,6 @@ type RemoteReadReceipt interface { GetReadUpTo() time.Time } -type RemoteReadReceiptWithStreamOrder interface { - RemoteReadReceipt - GetReadUpToStreamOrder() int64 -} - type RemoteDeliveryReceipt interface { RemoteEvent GetReceiptTargets() []networkid.MessageID @@ -1337,7 +1021,6 @@ type OrigSender struct { RequiresDisambiguation bool DisambiguatedName string FormattedName string - PerMessageProfile event.BeeperPerMessageProfile event.MemberEventContent } @@ -1352,16 +1035,12 @@ type MatrixEventBase[ContentType any] struct { // The original sender user ID. Only present in case the event is being relayed (and Sender is not the same user). OrigSender *OrigSender - - InputTransactionID networkid.RawTransactionID } type MatrixMessage struct { MatrixEventBase[*event.MessageEventContent] ThreadRoot *database.Message ReplyTo *database.Message - - pendingSaves []*outgoingMessage } type MatrixEdit struct { @@ -1369,17 +1048,6 @@ type MatrixEdit struct { EditTarget *database.Message } -type MatrixPollStart struct { - MatrixMessage - Content *event.PollStartEventContent -} - -type MatrixPollVote struct { - MatrixMessage - VoteTo *database.Message - Content *event.PollResponseEventContent -} - type MatrixReaction struct { MatrixEventBase[*event.ReactionEventContent] TargetMessage *database.Message @@ -1410,14 +1078,12 @@ type MatrixMessageRemove struct { type MatrixRoomMeta[ContentType any] struct { MatrixEventBase[ContentType] - PrevContent ContentType - IsStateRequest bool + PrevContent ContentType } type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent] type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent] -type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer] type MatrixReadReceipt struct { Portal *Portal @@ -1432,8 +1098,6 @@ type MatrixReadReceipt struct { LastRead time.Time // The receipt metadata. Receipt event.ReadReceipt - // Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt. - Implicit bool } type MatrixTyping struct { @@ -1442,14 +1106,6 @@ type MatrixTyping struct { Type TypingType } -type MatrixViewingChat struct { - // The portal that the user is viewing. This will be nil when the user switches to a chat from a different bridge. - Portal *Portal -} - -type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] -type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent] -type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent] type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5ba29507..fc047beb 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -14,14 +14,11 @@ import ( "runtime/debug" "strings" "sync" - "sync/atomic" "time" "github.com/rs/zerolog" "go.mau.fi/util/exfmt" - "go.mau.fi/util/exmaps" "go.mau.fi/util/exslices" - "go.mau.fi/util/exsync" "go.mau.fi/util/ptr" "go.mau.fi/util/variationselector" "golang.org/x/exp/maps" @@ -40,9 +37,8 @@ type portalMatrixEvent struct { } type portalRemoteEvent struct { - evt RemoteEvent - source *UserLogin - evtType RemoteEventType + evt RemoteEvent + source *UserLogin } type portalCreateEvent struct { @@ -61,12 +57,9 @@ type portalEvent interface { } type outgoingMessage struct { - db *database.Message - evt *event.Event - ignore bool - handle func(RemoteMessage, *database.Message) (bool, error) - ackedAt time.Time - timeouted bool + db *database.Message + evt *event.Event + handle func(RemoteMessage, *database.Message) (bool, error) } type Portal struct { @@ -79,28 +72,16 @@ type Portal struct { currentlyTyping []id.UserID currentlyTypingLogins map[id.UserID]*UserLogin currentlyTypingLock sync.Mutex - currentlyTypingGhosts *exsync.Set[id.UserID] - outgoingMessages map[networkid.TransactionID]*outgoingMessage + outgoingMessages map[networkid.TransactionID]outgoingMessage outgoingMessagesLock sync.Mutex - lastCapUpdate time.Time + roomCreateLock sync.Mutex - roomCreateLock sync.Mutex - cancelRoomCreate atomic.Pointer[context.CancelFunc] - RoomCreated *exsync.Event - - functionalMembersLock sync.Mutex - functionalMembersCache *event.ElementFunctionalMembersContent - - events chan portalEvent - deleted *exsync.Event - - eventsLock sync.Mutex - eventIdx int + events chan portalEvent } -var PortalEventBuffer = 64 +const PortalEventBuffer = 64 func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, key *networkid.PortalKey) (*Portal, error) { if queryErr != nil { @@ -123,55 +104,34 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Portal: dbPortal, Bridge: br, + events: make(chan portalEvent, PortalEventBuffer), currentlyTypingLogins: make(map[id.UserID]*UserLogin), - currentlyTypingGhosts: exsync.NewSet[id.UserID](), - outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), - - RoomCreated: exsync.NewEvent(), - deleted: exsync.NewEvent(), + outgoingMessages: make(map[networkid.TransactionID]outgoingMessage), } - if portal.MXID != "" { - portal.RoomCreated.Set() - } - // Putting the portal in the cache before it's fully initialized is mildly dangerous, - // but loading the relay user login may depend on it. br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal } var err error - if portal.ParentKey.ID != "" { - portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false) + if portal.ParentID != "" { + portal.Parent, err = br.UnlockedGetPortalByKey(ctx, networkid.PortalKey{ID: portal.ParentID}, false) if err != nil { - delete(br.portalsByKey, portal.PortalKey) - if portal.MXID != "" { - delete(br.portalsByMXID, portal.MXID) - } - return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err) + return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) } } if portal.RelayLoginID != "" { portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID) if err != nil { - delete(br.portalsByKey, portal.PortalKey) - if portal.MXID != "" { - delete(br.portalsByMXID, portal.MXID) - } return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err) } } portal.updateLogger() - if PortalEventBuffer != 0 { - portal.events = make(chan portalEvent, PortalEventBuffer) - go portal.eventLoop() - } + go portal.eventLoop() return portal, nil } func (portal *Portal) updateLogger() { - logWith := portal.Bridge.Log.With(). - Str("portal_id", string(portal.ID)). - Str("portal_receiver", string(portal.Receiver)) + logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID)) if portal.MXID != "" { logWith = logWith.Stringer("portal_mxid", portal.MXID) } @@ -195,20 +155,7 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta return output, nil } -func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) { - if dbPortal == nil { - return nil, nil - } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { - return cached, nil - } else { - return br.loadPortal(ctx, dbPortal, nil, nil) - } -} - func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { - if br.Config.SplitPortals && key.Receiver == "" { - return nil, fmt.Errorf("receiver must always be set when split portals is enabled") - } cached, ok := br.portalsByKey[key] if ok { return cached, nil @@ -234,9 +181,6 @@ func (br *Bridge) FindPortalReceiver(ctx context.Context, id networkid.PortalID, } func (br *Bridge) FindCachedPortalReceiver(id networkid.PortalID, maybeReceiver networkid.UserLoginID) networkid.PortalKey { - if br.Config.SplitPortals { - return networkid.PortalKey{ID: id, Receiver: maybeReceiver} - } br.cacheLock.Lock() defer br.cacheLock.Unlock() portal, ok := br.portalsByKey[networkid.PortalKey{ @@ -294,26 +238,6 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us return br.loadManyPortals(ctx, rows) } -func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) { - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - rows, err := br.DB.Portal.GetChildren(ctx, parent) - if err != nil { - return nil, err - } - return br.loadManyPortals(ctx, rows) -} - -func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID) - if err != nil { - return nil, err - } - return br.loadPortalWithCacheCheck(ctx, dbPortal) -} - func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -323,7 +247,7 @@ func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) ( func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - if key.Receiver == "" || br.Config.SplitPortals { + if key.Receiver == "" { return br.UnlockedGetPortalByKey(ctx, key, true) } cached, ok := br.portalsByKey[key] @@ -338,289 +262,47 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port return br.loadPortal(ctx, db, err, nil) } -func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { - if portal.deleted.IsSet() { - return EventHandlingResultIgnored - } - if PortalEventBuffer == 0 { - portal.eventsLock.Lock() - defer portal.eventsLock.Unlock() - portal.eventIdx++ - return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt) - } else { - if portal.events == nil { - panic(fmt.Errorf("queueEvent into uninitialized portal %s", portal.PortalKey)) - } - select { - case portal.events <- evt: - return EventHandlingResultQueued - case <-portal.deleted.GetChan(): - return EventHandlingResultIgnored - default: - zerolog.Ctx(ctx).Error(). - Str("portal_id", string(portal.ID)). - Msg("Portal event channel is full, queue will block") - for { - select { - case portal.events <- evt: - return EventHandlingResultQueued - case <-time.After(5 * time.Second): - zerolog.Ctx(ctx).Error(). - Str("portal_id", string(portal.ID)). - Msg("Portal event channel is still full") - } - } - } +func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { + select { + case portal.events <- evt: + default: + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is full") } } func (portal *Portal) eventLoop() { - if cfg := portal.Bridge.Network.GetCapabilities().OutgoingMessageTimeouts; cfg != nil { - ctx, cancel := context.WithCancel(portal.Log.WithContext(portal.Bridge.BackgroundCtx)) - go portal.pendingMessageTimeoutLoop(ctx, cfg) - defer cancel() - } - deleteCh := portal.deleted.GetChan() - for i := 0; ; i++ { - select { - case rawEvt := <-portal.events: - if rawEvt == nil { - return - } - if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEventWithDelayLogging(i, rawEvt) - } else { - portal.handleSingleEventWithDelayLogging(i, rawEvt) - } - case <-deleteCh: - return + for rawEvt := range portal.events { + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + portal.handleMatrixEvent(evt.sender, evt.evt) + case *portalRemoteEvent: + portal.handleRemoteEvent(evt.source, evt.evt) + case *portalCreateEvent: + portal.handleCreateEvent(evt) + default: + panic(fmt.Errorf("illegal type %T in eventLoop", evt)) } } } -func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { - ctx := portal.getEventCtxWithLog(rawEvt, idx) - log := zerolog.Ctx(ctx) - doneCh := make(chan struct{}) - var backgrounded atomic.Bool - start := time.Now() - var handleDuration time.Duration - // Note: this will not set the success flag if the handler times out - outerRes = EventHandlingResult{Queued: true} - go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { - outerRes = res - handleDuration = time.Since(start) - close(doneCh) - if backgrounded.Load() { - log.Debug(). - Time("started_at", start). - Stringer("duration", handleDuration). - Msg("Event that took too long finally finished handling") - } - }) - tick := time.NewTicker(30 * time.Second) - _, isCreate := rawEvt.(*portalCreateEvent) - defer tick.Stop() - for i := 0; i < 10; i++ { - select { - case <-doneCh: - if i > 0 { - log.Debug(). - Time("started_at", start). - Stringer("duration", handleDuration). - Msg("Event that took long finished handling") - } - return - case <-tick.C: - log.Warn(). - Time("started_at", start). - Msg("Event handling is taking long") - if isCreate { - // Never background portal creation events - i = 1 - } - } - } - log.Warn(). - Time("started_at", start). - Msg("Event handling is taking too long, continuing in background") - backgrounded.Store(true) - return -} - -type contextKey int - -const ( - contextKeyRemoteEvent contextKey = iota - contextKeyMatrixEvent -) - -func GetMatrixEventFromContext(ctx context.Context) (evt *event.Event) { - evt, _ = ctx.Value(contextKeyMatrixEvent).(*event.Event) - return -} - -func GetRemoteEventFromContext(ctx context.Context) (evt RemoteEvent) { - evt, _ = ctx.Value(contextKeyRemoteEvent).(RemoteEvent) - return -} - -func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { - var logWith zerolog.Context - switch evt := rawEvt.(type) { - case *portalMatrixEvent: - logWith = portal.Log.With().Int("event_loop_index", idx). - Str("action", "handle matrix event"). - Stringer("event_id", evt.evt.ID). - Str("event_type", evt.evt.Type.Type) - if evt.evt.Type.Class != event.EphemeralEventType { - logWith = logWith. - Stringer("event_id", evt.evt.ID). - Stringer("sender", evt.sender.MXID) - } - ctx := portal.Bridge.BackgroundCtx - ctx = context.WithValue(ctx, contextKeyMatrixEvent, evt.evt) - ctx = logWith.Logger().WithContext(ctx) - return ctx - case *portalRemoteEvent: - evt.evtType = evt.evt.GetType() - logWith = portal.Log.With().Int("event_loop_index", idx). - Str("action", "handle remote event"). - Str("source_id", string(evt.source.ID)). - Stringer("bridge_evt_type", evt.evtType) - logWith = evt.evt.AddLogContext(logWith) - if remoteSender := evt.evt.GetSender(); remoteSender.Sender != "" || remoteSender.IsFromMe { - logWith = logWith.Object("remote_sender", remoteSender) - } - if remoteMsg, ok := evt.evt.(RemoteMessage); ok { - if remoteMsgID := remoteMsg.GetID(); remoteMsgID != "" { - logWith = logWith.Str("remote_message_id", string(remoteMsgID)) - } - } - if remoteMsg, ok := evt.evt.(RemoteEventWithTargetMessage); ok { - if targetMsgID := remoteMsg.GetTargetMessage(); targetMsgID != "" { - logWith = logWith.Str("remote_target_message_id", string(targetMsgID)) - } - } - if remoteMsg, ok := evt.evt.(RemoteEventWithStreamOrder); ok { - if remoteStreamOrder := remoteMsg.GetStreamOrder(); remoteStreamOrder != 0 { - logWith = logWith.Int64("remote_stream_order", remoteStreamOrder) - } - } - if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok { - if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() { - logWith = logWith.Time("remote_timestamp", remoteTimestamp) - } - } - ctx := portal.Bridge.BackgroundCtx - ctx = context.WithValue(ctx, contextKeyRemoteEvent, evt.evt) - ctx = logWith.Logger().WithContext(ctx) - if ctxMut, ok := evt.evt.(RemoteEventWithContextMutation); ok { - ctx = ctxMut.MutateContext(ctx) - } - return ctx - case *portalCreateEvent: - return evt.ctx - default: - panic(fmt.Errorf("invalid type %T in getEventCtxWithLog", evt)) - } -} - -func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(res EventHandlingResult)) { - log := zerolog.Ctx(ctx) - var res EventHandlingResult +func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { defer func() { - doneCallback(res) if err := recover(); err != nil { - logEvt := log.Error() - var errorString string + logEvt := zerolog.Ctx(evt.ctx).Error() if realErr, ok := err.(error); ok { logEvt = logEvt.Err(realErr) - errorString = realErr.Error() } else { logEvt = logEvt.Any(zerolog.ErrorFieldName, err) - errorString = fmt.Sprintf("%v", err) } logEvt. Bytes("stack", debug.Stack()). - Msg("Event handling panicked") - switch evt := rawEvt.(type) { - case *portalMatrixEvent: - if evt.evt.ID != "" { - go portal.sendErrorStatus(ctx, evt.evt, ErrPanicInEventHandler) - } - case *portalCreateEvent: - evt.cb(fmt.Errorf("portal creation panicked")) - } - portal.Bridge.TrackAnalytics("", "Bridge Event Handler Panic", map[string]any{ - "error": errorString, - }) + Msg("Portal creation panicked") + evt.cb(fmt.Errorf("portal creation panicked")) } }() - switch evt := rawEvt.(type) { - case *portalMatrixEvent: - isStateRequest := evt.evt.Type == event.BeeperSendState - if isStateRequest { - if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil { - portal.sendErrorStatus(ctx, evt.evt, err) - return - } - } - res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest) - if res.SendMSS { - if res.Error != nil { - portal.sendErrorStatus(ctx, evt.evt, res.Error) - } else { - portal.sendSuccessStatus(ctx, evt.evt, 0, "") - } - } - if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil { - portal.revertRoomMeta(ctx, evt.evt) - } - if isStateRequest && res.Success && !res.SkipStateEcho { - portal.sendRoomMeta( - ctx, - evt.sender.DoublePuppet(ctx), - time.UnixMilli(evt.evt.Timestamp), - evt.evt.Type, - evt.evt.GetStateKey(), - evt.evt.Content.Parsed, - false, - evt.evt.Content.Raw, - ) - } - case *portalRemoteEvent: - res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) - case *portalCreateEvent: - err := portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil) - res.Success = err == nil - evt.cb(err) - default: - panic(fmt.Errorf("illegal type %T in eventLoop", evt)) - } -} - -func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error { - content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent) - if !ok { - return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - } - evt.Content = content.Content - evt.StateKey = &content.StateKey - evt.Type = event.Type{Type: content.Type, Class: event.StateEventType} - _ = evt.Content.ParseRaw(evt.Type) - mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) - if !ok { - return fmt.Errorf("matrix connector doesn't support fetching state") - } - prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey()) - if err != nil && !errors.Is(err, mautrix.MNotFound) { - return fmt.Errorf("failed to get prev event: %w", err) - } else if prevEvt != nil { - evt.Unsigned.PrevContent = &prevEvt.Content - evt.Unsigned.PrevSender = prevEvt.Sender - } - return nil + evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) } func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { @@ -629,15 +311,12 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR if err != nil { return nil, nil, err } - if login == nil { - return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn) - } else if !login.Client.IsLoggedIn() { - return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn) - } else if login.UserMXID != user.MXID { + if login.UserMXID != user.MXID { if allowRelay && portal.Relay != nil { return nil, nil, nil } - return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID) + // TODO different error for this case? + return nil, nil, ErrNotLoggedIn } up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) return login, up, err @@ -648,9 +327,9 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } portal.Bridge.cacheLock.Lock() defer portal.Bridge.cacheLock.Unlock() - for _, up := range logins { + for i, up := range logins { login, ok := user.logins[up.LoginID] - if ok && login.Client != nil && login.Client.IsLoggedIn() { + if ok && login.Client != nil && (len(logins) == i-1 || login.Client.IsLoggedIn()) { return login, up, nil } } @@ -666,7 +345,7 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR firstLogin = login break } - if firstLogin != nil && firstLogin.Client.IsLoggedIn() { + if firstLogin != nil { zerolog.Ctx(ctx).Warn(). Str("chosen_login_id", string(firstLogin.ID)). Msg("No usable user portal rows found, returning random login") @@ -676,13 +355,8 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } } -func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { - info := StatusEventInfoFromEvent(evt) - info.StreamOrder = streamOrder - if newEventID != evt.ID { - info.NewEventID = newEventID - } - portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{Status: event.MessageStatusSuccess}, info) +func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event) { + portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{Status: event.MessageStatusSuccess}, StatusEventInfoFromEvent(evt)) } func (portal *Portal) sendErrorStatus(ctx context.Context, evt *event.Event, err error) { @@ -718,45 +392,54 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, return false } -var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} - -func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { - log := zerolog.Ctx(ctx) +func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { + log := portal.Log.With(). + Str("action", "handle matrix event"). + Str("event_type", evt.Type.Type). + Logger() + ctx := log.WithContext(context.TODO()) + defer func() { + if err := recover(); err != nil { + logEvt := log.Error() + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt. + Bytes("stack", debug.Stack()). + Msg("Matrix event handler panicked") + if evt.ID != "" { + go portal.sendErrorStatus(ctx, evt, ErrPanicInEventHandler) + } + } + }() if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { case event.EphemeralEventReceipt: - return portal.handleMatrixReceipts(ctx, evt) + portal.handleMatrixReceipts(ctx, evt) case event.EphemeralEventTyping: - return portal.handleMatrixTyping(ctx, evt) - case event.BeeperEphemeralEventAIStream: - return portal.handleMatrixAIStream(ctx, sender, evt) - default: - return EventHandlingResultIgnored + portal.handleMatrixTyping(ctx, evt) } + return } - if evt.Type == event.StateTombstone { - // Tombstones aren't bridged so they don't need a login - return portal.handleMatrixTombstone(ctx, evt) - } - login, userPortal, err := portal.FindPreferredLogin(ctx, sender, true) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Stringer("event_id", evt.ID). + Stringer("sender", sender.MXID) + }) + login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") if errors.Is(err, ErrNotLoggedIn) { - shouldSendNotice := evt.Content.AsMessage().MsgType != event.MsgNotice - return EventHandlingResultFailed.WithMSSError( - WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(shouldSendNotice), - ) + portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(true)) } else { - return EventHandlingResultFailed.WithMSSError( - WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true), - ) + portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) } + return } var origSender *OrigSender if login == nil { - if isStateRequest { - return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest) - } login = portal.Relay origSender = &OrigSender{ User: sender, @@ -778,88 +461,48 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } else { origSender.DisambiguatedName = sender.MXID.String() } - msg := evt.Content.AsMessage() - if msg != nil && msg.BeeperPerMessageProfile != nil && msg.BeeperPerMessageProfile.Displayname != "" { - pmp := msg.BeeperPerMessageProfile - origSender.PerMessageProfile = *pmp - roomPLs, err := portal.Bridge.Matrix.GetPowerLevels(ctx, portal.MXID) - if err != nil { - log.Warn().Err(err).Msg("Failed to get power levels to check relay profile") - } - if roomPLs != nil && - roomPLs.GetUserLevel(sender.MXID) >= roomPLs.GetEventLevel(fakePerMessageProfileEventType) && - !portal.checkConfusableName(ctx, sender.MXID, pmp.Displayname) { - origSender.DisambiguatedName = pmp.Displayname - origSender.RequiresDisambiguation = false - } else { - origSender.DisambiguatedName = fmt.Sprintf("%s via %s", pmp.Displayname, origSender.DisambiguatedName) - } - } - origSender.FormattedName = portal.Bridge.Config.Relay.FormatName(origSender) } - // Copy logger because many of the handlers will use UpdateContext - ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) - - if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts && !evt.Type.IsAccountData() { - rrLog := log.With().Str("subaction", "implicit read receipt").Logger() - rrCtx := rrLog.WithContext(ctx) - rrLog.Debug().Msg("Sending implicit read receipt for event") - evtTS := time.UnixMilli(evt.Timestamp) - portal.callReadReceiptHandler(rrCtx, login, nil, &MatrixReadReceipt{ - Portal: portal, - EventID: evt.ID, - Implicit: true, - ReadUpTo: evtTS, - Receipt: event.ReadReceipt{Timestamp: evtTS}, - }, userPortal) - } - + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("login_id", string(login.ID)) + }) switch evt.Type { - case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse: - return portal.handleMatrixMessage(ctx, login, origSender, evt) + case event.EventMessage, event.EventSticker: + portal.handleMatrixMessage(ctx, login, origSender, evt) case event.EventReaction: if origSender != nil { log.Debug().Msg("Ignoring reaction event from relayed user") - return EventHandlingResultIgnored.WithMSSError(ErrIgnoringReactionFromRelayedUser) + portal.sendErrorStatus(ctx, evt, ErrIgnoringReactionFromRelayedUser) + return } - return portal.handleMatrixReaction(ctx, login, evt) + portal.handleMatrixReaction(ctx, login, evt) case event.EventRedaction: - return portal.handleMatrixRedaction(ctx, login, origSender, evt) + portal.handleMatrixRedaction(ctx, login, origSender, evt) case event.StateRoomName: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) + handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) case event.StateTopic: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) + handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) - case event.StateBeeperDisappearingTimer: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) + handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) case event.StateEncryption: // TODO? - return EventHandlingResultIgnored case event.AccountDataMarkedUnread: - return handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) + handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) case event.AccountDataRoomTags: - return handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) + handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) case event.AccountDataBeeperMute: - return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) + handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) case event.StateMember: - return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest) + portal.handleMatrixMembership(ctx, login, origSender, evt) case event.StatePowerLevels: - return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest) - case event.BeeperDeleteChat: - return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) - case event.BeeperAcceptMessageRequest: - return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt) - default: - return EventHandlingResultIgnored + portal.handleMatrixPowerLevels(ctx, login, origSender, evt) } } -func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { +func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) { content, ok := evt.Content.Parsed.(*event.ReceiptEventContent) if !ok { - return EventHandlingResultFailed + return } for evtID, receipts := range *content { readReceipts, ok := receipts[event.ReceiptTypeRead] @@ -869,14 +512,12 @@ func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event for userID, receipt := range readReceipts { sender, err := portal.Bridge.GetUserByMXID(ctx, userID) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt") - return EventHandlingResultFailed.WithError(err) + // TODO log + return } portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) } } - // TODO actual status - return EventHandlingResultSuccess } func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { @@ -908,10 +549,15 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e EventID: eventID, Receipt: receipt, } + if userPortal == nil { + userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) + } else { + evt.LastRead = userPortal.LastRead + userPortal = userPortal.CopyWithoutValues() + } evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) if err != nil { log.Err(err).Msg("Failed to get exact message from database") - evt.ReadUpTo = receipt.Timestamp } else if evt.ExactMessage != nil { log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp) @@ -920,46 +566,27 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e } else { evt.ReadUpTo = receipt.Timestamp } - portal.callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) -} - -func (portal *Portal) callReadReceiptHandler( - ctx context.Context, - login *UserLogin, - rrClient ReadReceiptHandlingNetworkAPI, - evt *MatrixReadReceipt, - userPortal *database.UserPortal, -) { - if rrClient == nil { - var ok bool - rrClient, ok = login.Client.(ReadReceiptHandlingNetworkAPI) - if !ok { - return - } - } - if userPortal == nil { - userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) - } else { - evt.LastRead = userPortal.LastRead - userPortal = userPortal.CopyWithoutValues() - } - err := rrClient.HandleMatrixReadReceipt(ctx, evt) + err = rrClient.HandleMatrixReadReceipt(ctx, evt) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to handle read receipt") + log.Err(err).Msg("Failed to handle read receipt") return } - userPortal.LastRead = evt.ReadUpTo + if evt.ExactMessage != nil { + userPortal.LastRead = evt.ExactMessage.Timestamp + } else { + userPortal.LastRead = receipt.Timestamp + } err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata") + log.Err(err).Msg("Failed to save user portal metadata") } - portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo) + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } -func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { +func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) { content, ok := evt.Content.Parsed.(*event.TypingEventContent) if !ok { - return EventHandlingResultFailed + return } portal.currentlyTypingLock.Lock() defer portal.currentlyTypingLock.Unlock() @@ -970,52 +597,6 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) portal.sendTypings(ctx, stoppedTyping, false) portal.sendTypings(ctx, startedTyping, true) portal.currentlyTyping = content.UserIDs - // TODO actual status - return EventHandlingResultSuccess -} - -func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { - log := zerolog.Ctx(ctx) - if sender == nil { - log.Error().Msg("Missing sender for Matrix AI stream event") - return EventHandlingResultIgnored - } - login, _, err := portal.FindPreferredLogin(ctx, sender, true) - if err != nil { - log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event") - return EventHandlingResultFailed.WithMSSError(err) - } - var origSender *OrigSender - if login == nil { - if portal.Relay == nil { - return EventHandlingResultIgnored - } - login = portal.Relay - origSender = &OrigSender{ - User: sender, - UserID: sender.MXID, - } - } - content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent) - if !ok { - log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI) - if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported) - } - err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{ - Event: evt, - Content: content, - Portal: portal, - OrigSender: origSender, - }) - if err != nil { - log.Err(err).Msg("Failed to handle Matrix AI stream event") - return EventHandlingResultFailed.WithMSSError(err) - } - return EventHandlingResultSuccess.WithMSS() } func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { @@ -1106,156 +687,78 @@ func (portal *Portal) periodicTypingUpdater() { } } -func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { +func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { switch content.MsgType { case event.MsgText, event.MsgNotice, event.MsgEmote: // No checks for now, message length is safer to check after conversion inside connector case event.MsgLocation: - if caps.LocationMessage.Reject() { - return ErrLocationMessagesNotAllowed + if !caps.LocationMessages { + portal.sendErrorStatus(ctx, evt, ErrLocationMessagesNotAllowed) + return false } - case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile, event.CapMsgSticker: - capMsgType := content.GetCapMsgType() - feat, ok := caps.File[capMsgType] - if !ok { - return ErrUnsupportedMessageType - } - if content.MsgType != event.CapMsgSticker && - content.FileName != "" && - content.Body != content.FileName && - feat.Caption.Reject() { - return ErrCaptionsNotAllowed - } - if content.Info != nil { - dur := time.Duration(content.Info.Duration) * time.Millisecond - if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration { - if capMsgType == event.CapMsgVoice { - return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration)) - } - return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration)) - } - if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize { - return fmt.Errorf("%w: %.1f MiB is larger than the maximum of %.1f MiB", ErrMediaTooLarge, float64(content.Info.Size)/1024/1024, float64(feat.MaxSize)/1024/1024) - } - if content.Info.MimeType != "" && feat.GetMimeSupport(content.Info.MimeType).Reject() { - return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType) + case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile: + if content.FileName != "" && content.Body != content.FileName { + if !caps.Captions { + portal.sendErrorStatus(ctx, evt, ErrCaptionsNotAllowed) + return false } } - fallthrough default: } - return nil + return true } -func (portal *Portal) parseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { - if origSender != nil || !strings.HasPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix) { - return "" - } - return networkid.RawTransactionID(strings.TrimPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix)) -} - -func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { +func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) - var relatesTo *event.RelatesTo - var msgContent *event.MessageEventContent - var pollContent *event.PollStartEventContent - var pollResponseContent *event.PollResponseEventContent - var ok bool - if evt.Type == event.EventUnstablePollStart { - pollContent, ok = evt.Content.Parsed.(*event.PollStartEventContent) - relatesTo = pollContent.RelatesTo - } else if evt.Type == event.EventUnstablePollResponse { - pollResponseContent, ok = evt.Content.Parsed.(*event.PollResponseEventContent) - relatesTo = &pollResponseContent.RelatesTo - } else { - msgContent, ok = evt.Content.Parsed.(*event.MessageEventContent) - relatesTo = msgContent.RelatesTo - if evt.Type == event.EventSticker { - msgContent.MsgType = event.CapMsgSticker - } - if msgContent.MsgType == event.MsgNotice && !portal.Bridge.Config.BridgeNotices { - return EventHandlingResultIgnored.WithMSSError(ErrIgnoringMNotice) - } - } + content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed. - WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return } caps := sender.Client.GetCapabilities(ctx, portal) - if relatesTo.GetReplaceID() != "" { - if msgContent == nil { - log.Warn().Msg("Ignoring edit of poll") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w of polls", ErrEditsNotSupported)) - } - return portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) + if content.RelatesTo.GetReplaceID() != "" { + portal.handleMatrixEdit(ctx, sender, origSender, evt, content, caps) + return } var err error if origSender != nil { - if msgContent == nil { - log.Debug().Msg("Ignoring poll event from relayed user") - return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser) - } - if !caps.PerMessageProfileRelay { - msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) - if err != nil { - log.Err(err).Msg("Failed to format message for relaying") - return EventHandlingResultFailed.WithMSSError(err) - } + content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + portal.sendErrorStatus(ctx, evt, err) + return } } - if msgContent != nil { - if err = portal.checkMessageContentCaps(caps, msgContent); err != nil { - return EventHandlingResultFailed.WithMSSError(err) - } - } else if pollResponseContent != nil || pollContent != nil { - if _, ok = sender.Client.(PollHandlingNetworkAPI); !ok { - log.Debug().Msg("Ignoring poll event as network connector doesn't implement PollHandlingNetworkAPI") - return EventHandlingResultIgnored.WithMSSError(ErrPollsNotSupported) - } + if !portal.checkMessageContentCaps(ctx, caps, content, evt) { + return } - var threadRoot, replyTo, voteTo *database.Message - if evt.Type == event.EventUnstablePollResponse { - voteTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, relatesTo.GetReferenceID()) - if err != nil { - log.Err(err).Msg("Failed to get poll target message from database") - // TODO send status - return EventHandlingResultFailed - } else if voteTo == nil { - log.Warn().Stringer("vote_to_id", relatesTo.GetReferenceID()).Msg("Poll target message not found") - // TODO send status - return EventHandlingResultFailed - } - } + var threadRoot, replyTo *database.Message var replyToID id.EventID - threadRootID := relatesTo.GetThreadParent() - if caps.Thread.Partial() { - replyToID = relatesTo.GetNonFallbackReplyTo() - if threadRootID != "" { - threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) - if err != nil { - log.Err(err).Msg("Failed to get thread root message from database") - } else if threadRoot == nil { - log.Warn().Stringer("thread_root_id", threadRootID).Msg("Thread root message not found") - } - } + if caps.Threads { + replyToID = content.RelatesTo.GetNonFallbackReplyTo() } else { - replyToID = relatesTo.GetReplyTo() + replyToID = content.RelatesTo.GetReplyTo() } - if replyToID != "" && (caps.Reply.Partial() || caps.Thread.Partial()) { + threadRootID := content.RelatesTo.GetThreadParent() + if caps.Threads && threadRootID != "" { + threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") + } + } + if replyToID != "" && (caps.Replies || caps.Threads) { replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") - } else if replyTo == nil { - log.Warn().Stringer("reply_to_id", replyToID).Msg("Reply target message not found") } else { // Support replying to threads from non-thread-capable clients. // The fallback happens if the message is not a Matrix thread and either // * the replied-to message is in a thread, or // * the network only supports threads (assume the user wants to start a new thread) - if caps.Thread.Partial() && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Reply.Partial()) { + if caps.Threads && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Replies) { threadRootRemoteID := replyTo.ThreadRoot if threadRootRemoteID == "" { threadRootRemoteID = replyTo.ID @@ -1265,239 +768,84 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin log.Err(err).Msg("Failed to get thread root message from database (via reply fallback)") } } - if !caps.Reply.Partial() { + if !caps.Replies { replyTo = nil } } } - var messageTimer *event.BeeperDisappearingTimer - if msgContent != nil { - messageTimer = msgContent.BeeperDisappearingTimer - } - if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer { - log.Warn(). - Any("event_timer", messageTimer). - Any("portal_timer", portal.Disappear.ToEventContent()). - Msg("Mismatching disappearing timer in event") - } - wrappedMsgEvt := &MatrixMessage{ + resp, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ Event: evt, - Content: msgContent, + Content: content, OrigSender: origSender, Portal: portal, - - InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, ThreadRoot: threadRoot, ReplyTo: replyTo, - } - if portal.Bridge.Config.DeduplicateMatrixMessages { - if part, err := portal.Bridge.DB.Message.GetPartByTxnID(ctx, portal.Receiver, evt.ID, wrappedMsgEvt.InputTransactionID); err != nil { - log.Err(err).Msg("Failed to check db if message is already sent") - } else if part != nil { - log.Debug(). - Stringer("message_mxid", part.MXID). - Stringer("input_event_id", evt.ID). - Msg("Message already sent, ignoring") - return EventHandlingResultIgnored - } - } - - err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps) - if err != nil { - log.Warn().Err(err).Msg("Failed to auto-accept message request on message") - // TODO stop processing? - } - - var resp *MatrixMessageResponse - if msgContent != nil { - resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) - } else if pollContent != nil { - resp, err = sender.Client.(PollHandlingNetworkAPI).HandleMatrixPollStart(ctx, &MatrixPollStart{ - MatrixMessage: *wrappedMsgEvt, - Content: pollContent, - }) - } else if pollResponseContent != nil { - resp, err = sender.Client.(PollHandlingNetworkAPI).HandleMatrixPollVote(ctx, &MatrixPollVote{ - MatrixMessage: *wrappedMsgEvt, - VoteTo: voteTo, - Content: pollResponseContent, - }) - } else { - log.Error().Msg("Failed to handle Matrix message: all contents are nil?") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("all contents are nil")) - } + }) if err != nil { log.Err(err).Msg("Failed to handle Matrix message") - return EventHandlingResultFailed.WithMSSError(err) - } - message := wrappedMsgEvt.fillDBMessage(resp.DB) - if resp.Pending { - for _, save := range wrappedMsgEvt.pendingSaves { - save.ackedAt = time.Now() - } - } else { - if resp.DB == nil { - log.Error().Msg("Network connector didn't return a message to save") - } else { - if portal.Bridge.Config.OutgoingMessageReID { - message.MXID = portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, message.ID, message.PartID) - } - // Hack to ensure the ghost row exists - // TODO move to better place (like login) - portal.Bridge.GetGhostByID(ctx, message.SenderID) - err = portal.Bridge.DB.Message.Insert(ctx, message) - if err != nil { - log.Err(err).Msg("Failed to save message to database") - } else if resp.PostSave != nil { - resp.PostSave(ctx, message) - } - if resp.RemovePending != "" { - portal.outgoingMessagesLock.Lock() - delete(portal.outgoingMessages, resp.RemovePending) - portal.outgoingMessagesLock.Unlock() - } - } - portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) - } - ds := portal.Disappear - if messageTimer != nil { - ds = database.DisappearingSettingFromEvent(messageTimer) - } - if ds.Type != event.DisappearingTypeNone { - go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ - RoomID: portal.MXID, - EventID: message.MXID, - Timestamp: message.Timestamp, - DisappearingSetting: ds.StartingAt(message.Timestamp), - }) - } - if resp.Pending { - // Not exactly queued, but not finished either - return EventHandlingResultQueued - } - return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder) -} - -// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. -// -// This should be used when the network connector will return the real message ID from HandleMatrixMessage. -// The [MatrixMessageResponse] should include RemovePending with the transaction ID sto remove it from the lit -// after saving to database. -// -// See also: [MatrixMessage.AddPendingToSave] -func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { - evt.Portal.outgoingMessagesLock.Lock() - evt.Portal.outgoingMessages[txnID] = &outgoingMessage{ - ignore: true, - } - evt.Portal.outgoingMessagesLock.Unlock() -} - -// AddPendingToSave adds a transaction ID that should be processed and pointed at the existing event if encountered. -// -// This should be used when the network connector returns `Pending: true` from HandleMatrixMessage, -// i.e. when the network connector does not know the message ID at the end of the handler. -// The [MatrixMessageResponse] should set Pending to true to prevent saving the returned message to the database. -// -// The provided function will be called when the message is encountered. -func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) { - pending := &outgoingMessage{ - db: evt.fillDBMessage(message), - evt: evt.Event, - handle: handleEcho, - } - evt.Portal.outgoingMessagesLock.Lock() - evt.Portal.outgoingMessages[txnID] = pending - evt.pendingSaves = append(evt.pendingSaves, pending) - evt.Portal.outgoingMessagesLock.Unlock() -} - -// RemovePending removes a transaction ID from the list of pending messages. -// This should only be called if sending the message fails. -func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) { - evt.Portal.outgoingMessagesLock.Lock() - pendingSave := evt.Portal.outgoingMessages[txnID] - if pendingSave != nil { - evt.pendingSaves = slices.DeleteFunc(evt.pendingSaves, func(save *outgoingMessage) bool { - return save == pendingSave - }) - } - delete(evt.Portal.outgoingMessages, txnID) - evt.Portal.outgoingMessagesLock.Unlock() -} - -func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Message { - if message == nil { - message = &database.Message{} + portal.sendErrorStatus(ctx, evt, err) + return } + message := resp.DB if message.MXID == "" { - message.MXID = evt.Event.ID + message.MXID = evt.ID } if message.Room.ID == "" { - message.Room = evt.Portal.PortalKey + message.Room = portal.PortalKey } if message.Timestamp.IsZero() { - message.Timestamp = time.UnixMilli(evt.Event.Timestamp) + message.Timestamp = time.UnixMilli(evt.Timestamp) } - if message.ReplyTo.MessageID == "" && evt.ReplyTo != nil { - message.ReplyTo.MessageID = evt.ReplyTo.ID - message.ReplyTo.PartID = &evt.ReplyTo.PartID + if message.ReplyTo.MessageID == "" && replyTo != nil { + message.ReplyTo.MessageID = replyTo.ID + message.ReplyTo.PartID = &replyTo.PartID } - if message.ThreadRoot == "" && evt.ThreadRoot != nil { - message.ThreadRoot = evt.ThreadRoot.ID - if evt.ThreadRoot.ThreadRoot != "" { - message.ThreadRoot = evt.ThreadRoot.ThreadRoot + if message.ThreadRoot == "" && threadRoot != nil { + message.ThreadRoot = threadRoot.ID + if threadRoot.ThreadRoot != "" { + message.ThreadRoot = threadRoot.ThreadRoot } } if message.SenderMXID == "" { - message.SenderMXID = evt.Event.Sender + message.SenderMXID = evt.Sender } - if message.SendTxnID != "" { - message.SendTxnID = evt.InputTransactionID - } - return message -} - -func (portal *Portal) pendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { - ticker := time.NewTicker(cfg.CheckInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - portal.checkPendingMessages(ctx, cfg) - case <-ctx.Done(): - return + if resp.Pending != "" { + // TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request + // (for now this is fine because incoming messages will wait in the queue for this function to return) + portal.outgoingMessagesLock.Lock() + portal.outgoingMessages[resp.Pending] = outgoingMessage{ + db: message, + evt: evt, + handle: resp.HandleEcho, } - } -} - -func (portal *Portal) checkPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { - portal.outgoingMessagesLock.Lock() - defer portal.outgoingMessagesLock.Unlock() - for _, msg := range portal.outgoingMessages { - if msg.evt != nil && !msg.timeouted { - if cfg.NoEchoTimeout > 0 && !msg.ackedAt.IsZero() && time.Since(msg.ackedAt) > cfg.NoEchoTimeout { - msg.timeouted = true - portal.sendErrorStatus(ctx, msg.evt, ErrRemoteEchoTimeout.WithMessage(cfg.NoEchoMessage)) - } else if cfg.NoAckTimeout > 0 && time.Since(msg.db.Timestamp) > cfg.NoAckTimeout { - msg.timeouted = true - portal.sendErrorStatus(ctx, msg.evt, ErrRemoteAckTimeout.WithMessage(cfg.NoAckMessage)) - } + portal.outgoingMessagesLock.Unlock() + } else { + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, message.SenderID) + err = portal.Bridge.DB.Message.Insert(ctx, message) + if err != nil { + log.Err(err).Msg("Failed to save message to database") } + portal.sendSuccessStatus(ctx, evt) + } + if portal.Disappear.Type != database.DisappearingTypeNone { + go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: message.MXID, + DisappearingSetting: database.DisappearingSetting{ + Type: portal.Disappear.Type, + Timer: portal.Disappear.Timer, + DisappearAt: message.Timestamp.Add(portal.Disappear.Timer), + }, + }) } } -func (portal *Portal) handleMatrixEdit( - ctx context.Context, - sender *UserLogin, - origSender *OrigSender, - evt *event.Event, - content *event.MessageEventContent, - caps *event.RoomFeatures, -) EventHandlingResult { +func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { log := zerolog.Ctx(ctx) editTargetID := content.RelatesTo.GetReplaceID() log.UpdateContext(func(c zerolog.Context) zerolog.Context { @@ -1505,40 +853,44 @@ func (portal *Portal) handleMatrixEdit( }) if content.NewContent != nil { content = content.NewContent - if evt.Type == event.EventSticker { - content.MsgType = event.CapMsgSticker - } } if origSender != nil { var err error content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) if err != nil { log.Err(err).Msg("Failed to format message for relaying") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } } editingAPI, ok := sender.Client.(EditHandlingNetworkAPI) if !ok { log.Debug().Msg("Ignoring edit as network connector doesn't implement EditHandlingNetworkAPI") - return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupported) - } else if !caps.Edit.Partial() { + portal.sendErrorStatus(ctx, evt, ErrEditsNotSupported) + return + } else if !caps.Edits { log.Debug().Msg("Ignoring edit as room doesn't support edits") - return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupportedInPortal) - } else if err := portal.checkMessageContentCaps(caps, content); err != nil { - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, ErrEditsNotSupportedInPortal) + return + } else if !portal.checkMessageContentCaps(ctx, caps, content, evt) { + return } editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) if err != nil { log.Err(err).Msg("Failed to get edit target message from database") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) + return } else if editTarget == nil { log.Warn().Msg("Edit target message not found in database") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("edit %w", ErrTargetMessageNotFound)) - } else if caps.EditMaxAge != nil && caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { - return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooOld) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("edit %w", ErrTargetMessageNotFound)) + return + } else if caps.EditMaxAge > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge { + portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) + return } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { - return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooManyEdits) + portal.sendErrorStatus(ctx, evt, ErrEditTargetTooManyEdits) + return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("edit_target_remote_id", string(editTarget.ID)) @@ -1549,35 +901,34 @@ func (portal *Portal) handleMatrixEdit( Content: content, OrigSender: origSender, Portal: portal, - - InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, EditTarget: editTarget, }) if err != nil { log.Err(err).Msg("Failed to handle Matrix edit") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } err = portal.Bridge.DB.Message.Update(ctx, editTarget) if err != nil { log.Err(err).Msg("Failed to save message to database after editing") } - // TODO allow returning stream order from HandleMatrixEdit - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultSuccess + portal.sendSuccessStatus(ctx, evt) } -func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) { +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { log := zerolog.Ctx(ctx) reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) if !ok { log.Debug().Msg("Ignoring reaction as network connector doesn't implement ReactionHandlingNetworkAPI") - return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) + portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) + return } content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("reaction_target_mxid", content.RelatesTo.EventID) @@ -1585,16 +936,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi reactionTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.RelatesTo.EventID) if err != nil { log.Err(err).Msg("Failed to get reaction target message from database") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) + return } else if reactionTarget == nil { log.Warn().Msg("Reaction target message not found in database") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) - } - caps := sender.Client.GetCapabilities(ctx, portal) - err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps) - if err != nil { - log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction") - // TODO stop processing? + portal.sendErrorStatus(ctx, evt, fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) + return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) @@ -1604,64 +951,42 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi Event: evt, Content: content, Portal: portal, - - InputTransactionID: portal.parseInputTransactionID(nil, evt), }, TargetMessage: reactionTarget, } preResp, err := reactingAPI.PreHandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to pre-handle Matrix reaction") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } - var deterministicID id.EventID - if portal.Bridge.Config.OutgoingMessageReID { - deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) - } - defer func() { - // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction - if handleRes.Success { - portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - } - }() - removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) { - if !handleRes.Success { + existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) + if err != nil { + log.Err(err).Msg("Failed to check if reaction is a duplicate") + return + } else if existing != nil { + if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { + log.Debug().Msg("Ignoring duplicate reaction") + portal.sendSuccessStatus(ctx, evt) return } - _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + react.ReactionToOverride = existing + _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ - Redacts: oldReact.MXID, + Redacts: existing.MXID, }, }, nil) if err != nil { log.Err(err).Msg("Failed to remove old reaction") } - if deleteDB { - err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact) - if err != nil { - log.Err(err).Msg("Failed to delete old reaction from database") - } - } - } - existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) - if err != nil { - log.Err(err).Msg("Failed to check if reaction is a duplicate") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to check for existing reaction: %w", ErrDatabaseError, err)) - } else if existing != nil { - if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { - log.Debug().Msg("Ignoring duplicate reaction") - portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - return EventHandlingResultIgnored.WithEventID(deterministicID) - } - react.ReactionToOverride = existing - defer removeOutdatedReaction(existing, false) } react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { - allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, portal.Receiver, reactionTarget.ID, preResp.SenderID) + allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, reactionTarget.ID, preResp.SenderID) if err != nil { log.Err(err).Msg("Failed to get all reactions to message by sender") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) + return } if len(allReactions) < preResp.MaxReactions { react.ExistingReactionsToKeep = allReactions @@ -1669,21 +994,26 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi // Keep n-1 previous reactions and remove the rest react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { - if existing != nil && oldReaction.EmojiID == existing.EmojiID { - // Don't double-delete on networks that only allow one emoji - continue + _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: oldReaction.MXID, + }, + }, nil) + if err != nil { + log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded") + } + err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction) + if err != nil { + log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded") } - // Intentionally defer in a loop, there won't be that many items, - // and we want all of them to be done after this function completes successfully - //goland:noinspection GoDeferInLoop - defer removeOutdatedReaction(oldReaction, true) } } } dbReaction, err := reactingAPI.HandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to handle Matrix reaction") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } if dbReaction == nil { dbReaction = &database.Reaction{} @@ -1696,9 +1026,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi dbReaction.MessageID = reactionTarget.ID dbReaction.MessagePartID = reactionTarget.PartID } - if deterministicID != "" { - dbReaction.MXID = deterministicID - } else if dbReaction.MXID == "" { + if dbReaction.MXID == "" { dbReaction.MXID = evt.ID } if dbReaction.Timestamp.IsZero() { @@ -1721,7 +1049,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - return EventHandlingResultSuccess.WithEventID(deterministicID) + portal.sendSuccessStatus(ctx, evt) } func handleMatrixRoomMeta[APIType any, ContentType any]( @@ -1730,53 +1058,35 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( sender *UserLogin, origSender *OrigSender, evt *event.Event, - isStateRequest bool, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), -) EventHandlingResult { - if evt.StateKey == nil || *evt.StateKey != "" { - return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) - } - //caps := sender.Client.GetCapabilities(ctx, portal) - //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported { - // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed)) - //} +) { api, ok := sender.Client.(APIType) if !ok { - return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type)) + portal.sendErrorStatus(ctx, evt, ErrRoomMetadataNotSupported) + return } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return } switch typedContent := evt.Content.Parsed.(type) { case *event.RoomNameEventContent: if typedContent.Name == portal.Name { - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultIgnored + portal.sendSuccessStatus(ctx, evt) + return } case *event.TopicEventContent: if typedContent.Topic == portal.Topic { - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultIgnored + portal.sendSuccessStatus(ctx, evt) + return } case *event.RoomAvatarEventContent: if typedContent.URL == portal.AvatarMXC { - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultIgnored - } - case *event.BeeperDisappearingTimer: - if typedContent.Type == event.DisappearingTypeNone || typedContent.Timer.Duration <= 0 { - typedContent.Type = event.DisappearingTypeNone - typedContent.Timer.Duration = 0 - } - if typedContent.Type == portal.Disappear.Type && typedContent.Timer.Duration == portal.Disappear.Timer { - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultIgnored - } - if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { - return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) + portal.sendSuccessStatus(ctx, evt) + return } } var prevContent ContentType @@ -1791,41 +1101,37 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( Content: content, Portal: portal, OrigSender: origSender, - - InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - IsStateRequest: isStateRequest, - PrevContent: prevContent, + PrevContent: prevContent, }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } if changed { - if evt.Type != event.StateBeeperDisappearingTimer { - portal.UpdateBridgeInfo(ctx) - } + portal.UpdateBridgeInfo(ctx) err = portal.Save(ctx) if err != nil { log.Err(err).Msg("Failed to save portal after updating room metadata") } } - return EventHandlingResultSuccess.WithMSS() + portal.sendSuccessStatus(ctx, evt) } func handleMatrixAccountData[APIType any, ContentType any]( portal *Portal, ctx context.Context, sender *UserLogin, evt *event.Event, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) error, -) EventHandlingResult { +) { api, ok := sender.Client.(APIType) if !ok { - return EventHandlingResultIgnored + return } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return } var prevContent ContentType if evt.Unsigned.PrevContent != nil { @@ -1843,9 +1149,7 @@ func handleMatrixAccountData[APIType any, ContentType any]( }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room account data") - return EventHandlingResultFailed.WithError(err) } - return EventHandlingResultSuccess } func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { @@ -1865,144 +1169,18 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos } } -func (portal *Portal) handleMatrixAcceptMessageRequest( - ctx context.Context, - sender *UserLogin, - origSender *OrigSender, - evt *event.Event, -) EventHandlingResult { - if origSender != nil { - return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser) - } - log := zerolog.Ctx(ctx) - content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent) - if !ok { - log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) - if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) - } - err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ - Event: evt, - Content: content, - Portal: portal, - }) - if err != nil { - log.Err(err).Msg("Failed to handle Matrix accept message request") - return EventHandlingResultFailed.WithMSSError(err) - } - if portal.MessageRequest { - portal.MessageRequest = false - portal.UpdateBridgeInfo(ctx) - err = portal.Save(ctx) - if err != nil { - log.Err(err).Msg("Failed to save portal after accepting message request") - } - } - return EventHandlingResultSuccess.WithMSS() -} - -func (portal *Portal) autoAcceptMessageRequest( - ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures, -) error { - if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported { - return nil - } - mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) - if !ok { - return nil - } - err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ - Event: evt, - Content: &event.BeeperAcceptMessageRequestEventContent{ - IsImplicit: true, - }, - Portal: portal, - OrigSender: origSender, - }) - if err != nil { - return err - } - if portal.MessageRequest { - portal.MessageRequest = false - portal.UpdateBridgeInfo(ctx) - err = portal.Save(ctx) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request") - } - } - return nil -} - -func (portal *Portal) handleMatrixDeleteChat( - ctx context.Context, - sender *UserLogin, - origSender *OrigSender, - evt *event.Event, -) EventHandlingResult { - if origSender != nil { - return EventHandlingResultFailed.WithMSSError(ErrIgnoringDeleteChatRelayedUser) - } - log := zerolog.Ctx(ctx) - content, ok := evt.Content.Parsed.(*event.BeeperChatDeleteEventContent) - if !ok { - log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - api, ok := sender.Client.(DeleteChatHandlingNetworkAPI) - if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) - } - err := api.HandleMatrixDeleteChat(ctx, &MatrixDeleteChat{ - Event: evt, - Content: content, - Portal: portal, - }) - if err != nil { - log.Err(err).Msg("Failed to handle Matrix chat delete") - return EventHandlingResultFailed.WithMSSError(err) - } - if portal.Receiver == "" { - _, others, err := portal.findOtherLogins(ctx, sender) - if err != nil { - log.Err(err).Msg("Failed to check if portal has other logins") - return EventHandlingResultFailed.WithError(err) - } else if len(others) > 0 { - log.Debug().Msg("Not deleting portal after chat delete as other logins are present") - return EventHandlingResultSuccess - } - } - err = portal.Delete(ctx) - if err != nil { - log.Err(err).Msg("Failed to delete portal from database") - return EventHandlingResultFailed.WithMSSError(err) - } - err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) - if err != nil { - log.Err(err).Msg("Failed to delete Matrix room") - return EventHandlingResultFailed.WithMSSError(err) - } - // No MSS here as the portal was deleted - return EventHandlingResultSuccess -} - func (portal *Portal) handleMatrixMembership( ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, - isStateRequest bool, -) EventHandlingResult { - if evt.StateKey == nil { - return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) - } +) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.MemberEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return } prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} if evt.Unsigned.PrevContent != nil { @@ -2017,22 +1195,26 @@ func (portal *Portal) handleMatrixMembership( }) api, ok := sender.Client.(MembershipHandlingNetworkAPI) if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrMembershipNotSupported) + portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) + return } targetMXID := id.UserID(*evt.StateKey) isSelf := sender.User.MXID == targetMXID target, err := portal.getTargetUser(ctx, targetMXID) if err != nil { log.Err(err).Msg("Failed to get member event target") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { log.Debug().Msg("Dropping leave event") - return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent) + //portal.sendErrorStatus(ctx, evt, ErrIgnoringLeaveEvent) + return } targetGhost, _ := target.(*Ghost) + targetUserLogin, _ := target.(*UserLogin) membershipChange := &MatrixMembershipChange{ MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{ MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ @@ -2040,63 +1222,20 @@ func (portal *Portal) handleMatrixMembership( Content: content, Portal: portal, OrigSender: origSender, - - InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - IsStateRequest: isStateRequest, - PrevContent: prevContent, + PrevContent: prevContent, }, - Target: target, - Type: membershipChangeType, + Target: target, + TargetGhost: targetGhost, + TargetUserLogin: targetUserLogin, + Type: membershipChangeType, } - res, err := api.HandleMatrixMembership(ctx, membershipChange) + _, err = api.HandleMatrixMembership(ctx, membershipChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix membership change") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } - didRedirectInvite := membershipChangeType == Invite && - targetGhost != nil && - res != nil && - res.RedirectTo != "" && - res.RedirectTo != targetGhost.ID - if didRedirectInvite { - log.Debug(). - Str("orig_id", string(targetGhost.ID)). - Str("redirect_id", string(res.RedirectTo)). - Msg("Invite was redirected to different ghost") - var redirectGhost *Ghost - redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo) - if err != nil { - log.Err(err).Msg("Failed to get redirect target ghost") - return EventHandlingResultFailed.WithError(err) - } - if !isStateRequest { - portal.sendRoomMeta( - ctx, - sender.User.DoublePuppet(ctx), - time.UnixMilli(evt.Timestamp), - event.StateMember, - evt.GetStateKey(), - &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo), - }, - true, - nil, - ) - } - portal.sendRoomMeta( - ctx, - sender.User.DoublePuppet(ctx), - time.UnixMilli(evt.Timestamp), - event.StateMember, - redirectGhost.Intent.GetMXID().String(), - content, - false, - nil, - ) - } - return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite) } func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { @@ -2121,36 +1260,23 @@ func (portal *Portal) handleMatrixPowerLevels( sender *UserLogin, origSender *OrigSender, evt *event.Event, - isStateRequest bool, -) EventHandlingResult { - if evt.StateKey == nil || *evt.StateKey != "" { - return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) - } +) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - if content.CreateEvent == nil { - ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) - if ok { - var err error - content.CreateEvent, err = ars.GetStateEvent(ctx, portal.MXID, event.StateCreate, "") - if err != nil { - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err)) - } - } + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return } api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrPowerLevelsNotSupported) + portal.sendErrorStatus(ctx, evt, ErrPowerLevelsNotSupported) + return } prevContent := &event.PowerLevelsEventContent{} if evt.Unsigned.PrevContent != nil { _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.PowerLevelsEventContent) - prevContent.CreateEvent = content.CreateEvent } plChange := &MatrixPowerLevelChange{ @@ -2160,11 +1286,8 @@ func (portal *Portal) handleMatrixPowerLevels( Content: content, Portal: portal, OrigSender: origSender, - - InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - IsStateRequest: isStateRequest, - PrevContent: prevContent, + PrevContent: prevContent, }, Users: make(map[id.UserID]*UserPowerLevelChange), Events: make(map[string]*SinglePowerLevelChange), @@ -2205,269 +1328,18 @@ func (portal *Portal) handleMatrixPowerLevels( _, err := api.HandleMatrixPowerLevels(ctx, plChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix power level change") - return EventHandlingResultFailed.WithMSSError(err) - } - return EventHandlingResultSuccess.WithMSS() -} - -func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { - if evt.StateKey == nil || *evt.StateKey != "" || portal.MXID != evt.RoomID { - return EventHandlingResultIgnored - } - log := *zerolog.Ctx(ctx) - sentByBridge := evt.Sender == portal.Bridge.Bot.GetMXID() || portal.Bridge.IsGhostMXID(evt.Sender) - var senderUser *User - var err error - if !sentByBridge { - senderUser, err = portal.Bridge.GetUserByMXID(ctx, evt.Sender) - if err != nil { - log.Err(err).Msg("Failed to get tombstone sender user") - return EventHandlingResultFailed.WithError(err) - } - } - content, ok := evt.Content.Parsed.(*event.TombstoneEventContent) - if !ok { - log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - log = log.With(). - Stringer("replacement_room", content.ReplacementRoom). - Logger() - if content.ReplacementRoom == "" { - log.Info().Msg("Received tombstone with no replacement room, cleaning up portal") - err := portal.RemoveMXID(ctx) - if err != nil { - log.Err(err).Msg("Failed to remove portal MXID") - return EventHandlingResultFailed.WithMSSError(err) - } - err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true) - if err != nil { - log.Err(err).Msg("Failed to clean up Matrix room") - return EventHandlingResultFailed.WithError(err) - } - return EventHandlingResultSuccess - } - existingMemberEvt, err := portal.Bridge.Matrix.GetMemberInfo(ctx, content.ReplacementRoom, portal.Bridge.Bot.GetMXID()) - if err != nil { - log.Err(err).Msg("Failed to get member info of bot in replacement room") - return EventHandlingResultFailed.WithError(err) - } - leaveOnError := func() { - if existingMemberEvt != nil && existingMemberEvt.Membership == event.MembershipJoin { - return - } - log.Debug().Msg("Leaving replacement room with bot after tombstone validation failed") - _, err = portal.Bridge.Bot.SendState( - ctx, - content.ReplacementRoom, - event.StateMember, - portal.Bridge.Bot.GetMXID().String(), - &event.Content{ - Parsed: &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: fmt.Sprintf("Failed to validate tombstone sent by %s from %s", evt.Sender, evt.RoomID), - }, - }, - time.Time{}, - ) - if err != nil { - log.Err(err).Msg("Failed to leave replacement room after tombstone validation failed") - } - } - var via []string - if senderHS := evt.Sender.Homeserver(); senderHS != "" { - via = []string{senderHS} - } - err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom, EnsureJoinedParams{Via: via}) - if err != nil { - log.Err(err).Msg("Failed to join replacement room from tombstone") - return EventHandlingResultFailed.WithError(err) - } - if !sentByBridge && !senderUser.Permissions.Admin { - powers, err := portal.Bridge.Matrix.GetPowerLevels(ctx, content.ReplacementRoom) - if err != nil { - log.Err(err).Msg("Failed to get power levels in replacement room") - leaveOnError() - return EventHandlingResultFailed.WithError(err) - } - if powers.GetUserLevel(evt.Sender) < powers.Invite() { - log.Warn().Msg("Tombstone sender doesn't have enough power to invite the bot to the replacement room") - leaveOnError() - return EventHandlingResultIgnored - } - } - err = portal.UpdateMatrixRoomID(ctx, content.ReplacementRoom, UpdateMatrixRoomIDParams{ - DeleteOldRoom: true, - FetchInfoVia: senderUser, - }) - if errors.Is(err, ErrTargetRoomIsPortal) { - return EventHandlingResultIgnored - } else if err != nil { - return EventHandlingResultFailed.WithError(err) - } - return EventHandlingResultSuccess -} - -var ErrTargetRoomIsPortal = errors.New("target room is already a portal") -var ErrRoomAlreadyExists = errors.New("this portal already has a room") - -type UpdateMatrixRoomIDParams struct { - SyncDBMetadata func() - FailIfMXIDSet bool - OverwriteOldPortal bool - TombstoneOldRoom bool - DeleteOldRoom bool - - RoomCreateAlreadyLocked bool - - FetchInfoVia *User - ChatInfo *ChatInfo - ChatInfoSource *UserLogin -} - -func (portal *Portal) UpdateMatrixRoomID( - ctx context.Context, - newRoomID id.RoomID, - params UpdateMatrixRoomIDParams, -) error { - if !params.RoomCreateAlreadyLocked { - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - } - oldRoom := portal.MXID - if oldRoom == newRoomID { - return nil - } else if oldRoom != "" && params.FailIfMXIDSet { - return ErrRoomAlreadyExists - } - log := zerolog.Ctx(ctx) - portal.Bridge.cacheLock.Lock() - // Wrap unlock in a sync.OnceFunc because we want to both defer it to catch early returns - // and unlock it before return if nothing goes wrong. - unlockCacheLock := sync.OnceFunc(portal.Bridge.cacheLock.Unlock) - defer unlockCacheLock() - if existingPortal, alreadyExists := portal.Bridge.portalsByMXID[newRoomID]; alreadyExists && !params.OverwriteOldPortal { - log.Warn().Msg("Replacement room is already a portal, ignoring") - return ErrTargetRoomIsPortal - } else if alreadyExists { - log.Debug().Msg("Replacement room is already a portal, overwriting") - existingPortal.MXID = "" - existingPortal.RoomCreated.Clear() - err := existingPortal.Save(ctx) - if err != nil { - return fmt.Errorf("failed to clear mxid of existing portal: %w", err) - } - delete(portal.Bridge.portalsByMXID, portal.MXID) - } - portal.MXID = newRoomID - portal.RoomCreated.Set() - portal.Bridge.portalsByMXID[portal.MXID] = portal - portal.NameSet = false - portal.AvatarSet = false - portal.TopicSet = false - portal.InSpace = false - portal.CapState = database.CapabilityState{} - portal.lastCapUpdate = time.Time{} - if params.SyncDBMetadata != nil { - params.SyncDBMetadata() - } - unlockCacheLock() - portal.updateLogger() - - err := portal.Save(ctx) - if err != nil { - log.Err(err).Msg("Failed to save portal in UpdateMatrixRoomID") - return err - } - log.Info().Msg("Successfully followed tombstone and updated portal MXID") - err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to update in_space flag for user portals after updating portal MXID") - } - go portal.addToUserSpaces(ctx) - if params.FetchInfoVia != nil { - go portal.updateInfoAfterTombstone(ctx, params.FetchInfoVia) - } else if params.ChatInfo != nil { - go portal.UpdateInfo(ctx, params.ChatInfo, params.ChatInfoSource, nil, time.Time{}) - } else if params.ChatInfoSource != nil { - portal.UpdateCapabilities(ctx, params.ChatInfoSource, true) - portal.UpdateBridgeInfo(ctx) - } - go func() { - // TODO this might become unnecessary if UpdateInfo starts taking care of it - _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ - Parsed: &event.ElementFunctionalMembersContent{ - ServiceMembers: []id.UserID{portal.Bridge.Bot.GetMXID()}, - }, - }, time.Time{}) - if err != nil { - if err != nil { - log.Warn().Err(err).Msg("Failed to set service members in new room") - } - } - }() - if params.TombstoneOldRoom && oldRoom != "" { - _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateTombstone, "", &event.Content{ - Parsed: &event.TombstoneEventContent{ - Body: "Room has been replaced.", - ReplacementRoom: newRoomID, - }, - }, time.Now()) - if err != nil { - log.Err(err).Msg("Failed to send tombstone event to old room") - } - } - if params.DeleteOldRoom && oldRoom != "" { - go func() { - err = portal.Bridge.Bot.DeleteRoom(ctx, oldRoom, true) - if err != nil { - log.Err(err).Msg("Failed to clean up old Matrix room after updating portal MXID") - } - }() - } - return nil -} - -func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) { - log := zerolog.Ctx(ctx) - logins, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user logins in portal to sync info") + portal.sendErrorStatus(ctx, evt, err) return } - var preferredLogin *UserLogin - for _, login := range logins { - if !login.Client.IsLoggedIn() { - continue - } else if preferredLogin == nil { - preferredLogin = login - } else if senderUser != nil && login.User == senderUser { - preferredLogin = login - } - } - if preferredLogin == nil { - log.Warn().Msg("No logins found to sync info") - return - } - info, err := preferredLogin.Client.GetChatInfo(ctx, portal) - if err != nil { - log.Err(err).Msg("Failed to get chat info") - return - } - log.Info(). - Str("info_source_login", string(preferredLogin.ID)). - Msg("Fetched info to update portal after tombstone") - portal.UpdateInfo(ctx, info, preferredLogin, nil, time.Time{}) } -func (portal *Portal) handleMatrixRedaction( - ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, -) EventHandlingResult { +func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.RedactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return } if evt.Redacts != "" && content.Redacts != evt.Redacts { content.Redacts = evt.Redacts @@ -2479,17 +1351,20 @@ func (portal *Portal) handleMatrixRedaction( reactingAPI, reactOK := sender.Client.(ReactionHandlingNetworkAPI) if !deleteOK && !reactOK { log.Debug().Msg("Ignoring redaction without checking target as network connector doesn't implement RedactionHandlingNetworkAPI nor ReactionHandlingNetworkAPI") - return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) + portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) + return } var redactionTargetReaction *database.Reaction redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) if err != nil { log.Err(err).Msg("Failed to get redaction target message from database") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) + return } else if redactionTargetMsg != nil { if !deleteOK { log.Debug().Msg("Ignoring message redaction event as network connector doesn't implement RedactionHandlingNetworkAPI") - return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) + portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) + return } err = deletingAPI.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ @@ -2497,18 +1372,18 @@ func (portal *Portal) handleMatrixRedaction( Content: content, Portal: portal, OrigSender: origSender, - - InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, TargetMessage: redactionTargetMsg, }) } else if redactionTargetReaction, err = portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { log.Err(err).Msg("Failed to get redaction target reaction from database") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) + return } else if redactionTargetReaction != nil { if !reactOK { log.Debug().Msg("Ignoring reaction redaction event as network connector doesn't implement ReactionHandlingNetworkAPI") - return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) + portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) + return } // TODO ignore if sender doesn't match? err = reactingAPI.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ @@ -2517,30 +1392,52 @@ func (portal *Portal) handleMatrixRedaction( Content: content, Portal: portal, OrigSender: origSender, - - InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, TargetReaction: redactionTargetReaction, }) } else { log.Debug().Msg("Redaction target message not found in database") - return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) + portal.sendErrorStatus(ctx, evt, fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) + return } if err != nil { log.Err(err).Msg("Failed to handle Matrix redaction") - return EventHandlingResultFailed.WithMSSError(err) + portal.sendErrorStatus(ctx, evt, err) + return } // TODO delete msg/reaction db row - return EventHandlingResultSuccess.WithMSS() + portal.sendSuccessStatus(ctx, evt) } -func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { - log := zerolog.Ctx(ctx) +func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { + log := portal.Log.With(). + Str("source_id", string(source.ID)). + Str("action", "handle remote event"). + Logger() + defer func() { + if err := recover(); err != nil { + logEvt := log.Error() + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt. + Bytes("stack", debug.Stack()). + Msg("Remote event handler panicked") + } + }() + evtType := evt.GetType() + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + c = c.Stringer("bridge_evt_type", evtType) + return evt.AddLogContext(c) + }) + ctx := log.WithContext(context.TODO()) if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { log.Debug().Msg("Dropping event as portal doesn't exist") - return EventHandlingResultIgnored + return } infoProvider, ok := mcp.(RemoteChatResyncWithInfo) var info *ChatInfo @@ -2559,15 +1456,12 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, err = portal.createMatrixRoomInLoop(ctx, source, info, bundle) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") - return EventHandlingResultFailed.WithError(err) + // TODO error + return } if evtType == RemoteEventChatResync { log.Debug().Msg("Not handling chat resync event further as portal was created by it") - postHandler, ok := evt.(RemotePostHandler) - if ok { - postHandler.PostHandle(ctx, portal) - } - return EventHandlingResultSuccess + return } } preHandler, ok := evt.(RemotePreHandler) @@ -2578,109 +1472,57 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, switch evtType { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") - res = EventHandlingResultIgnored case RemoteEventMessage, RemoteEventMessageUpsert: - res = portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) + portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) case RemoteEventEdit: - res = portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) + portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) case RemoteEventReaction: - res = portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) + portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) case RemoteEventReactionRemove: - res = portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) + portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) case RemoteEventReactionSync: - res = portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) + portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) case RemoteEventMessageRemove: - res = portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) + portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) case RemoteEventReadReceipt: - res = portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) + portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) case RemoteEventMarkUnread: - res = portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) + portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) case RemoteEventDeliveryReceipt: - res = portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) + portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) case RemoteEventTyping: - res = portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) + portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) case RemoteEventChatInfoChange: - res = portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) + portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) case RemoteEventChatResync: - res = portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) + portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) case RemoteEventChatDelete: - res = portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) + portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) case RemoteEventBackfill: - res = portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) + portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) default: log.Warn().Msg("Got remote event with unknown type") } - postHandler, ok := evt.(RemotePostHandler) - if ok { - postHandler.PostHandle(ctx, portal) - } - return } -func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) { - if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" { - return - } - ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) - if !ok { - return - } - portal.functionalMembersLock.Lock() - defer portal.functionalMembersLock.Unlock() - var functionalMembers *event.ElementFunctionalMembersContent - if portal.functionalMembersCache != nil { - functionalMembers = portal.functionalMembersCache - } else { - evt, err := ars.GetStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "") - if err != nil && !errors.Is(err, mautrix.MNotFound) { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get functional members state event") - return - } - functionalMembers = &event.ElementFunctionalMembersContent{} - if evt != nil { - evtContent, ok := evt.Content.Parsed.(*event.ElementFunctionalMembersContent) - if ok && evtContent != nil { - functionalMembers = evtContent - } - } - } - // TODO what about non-double-puppeted user ghosts? - functionalMembers.Add(portal.Bridge.Bot.GetMXID()) - if functionalMembers.Add(ghost.Intent.GetMXID()) { - _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ - Parsed: functionalMembers, - }, time.Time{}) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to update functional members state event") - return - } - } -} - -func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { +func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { var ghost *Ghost if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID { zerolog.Ctx(ctx).Warn(). Str("original_id", string(sender.Sender)). Str("default_other_user", string(portal.OtherUserID)). Msg("Overriding event sender with primary other user in DM portal") - // Ensure the ghost row exists anyway to prevent foreign key errors when saving messages - // TODO it'd probably be better to override the sender in the saved message, but that's more effort - _, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) - if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get ghost with original user ID") - return - } sender.Sender = portal.OtherUserID } if sender.Sender != "" { + var err error ghost, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") return + } else { + ghost.UpdateInfoIfNecessary(ctx, source, evtType) } - ghost.UpdateInfoIfNecessary(ctx, source, evtType) - portal.ensureFunctionalMember(ctx, ghost) } if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) @@ -2715,90 +1557,58 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS return } -func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) (MatrixAPI, bool) { - intent, _, err := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) - if err != nil { - return nil, false - } +func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { + intent, _ := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) if intent == nil { // TODO this is very hacky - we should either insert an empty ghost row automatically // (and not fetch it at runtime) or make the message sender column nullable. portal.Bridge.GetGhostByID(ctx, "") intent = portal.Bridge.Bot - if intent == nil { - panic(fmt.Errorf("bridge bot is nil")) - } } - return intent, true + return intent } -func (portal *Portal) getRelationMeta( - ctx context.Context, - currentMsgID networkid.MessageID, - currentMsg *ConvertedMessage, - isBatchSend bool, -) (replyTo, threadRoot, prevThreadEvent *database.Message) { +func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { log := zerolog.Ctx(ctx) var err error - if currentMsg.ReplyTo != nil { - replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *currentMsg.ReplyTo) + if replyToPtr != nil { + replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *replyToPtr) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") } else if replyTo == nil { - if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { + if isBatchSend { // This is somewhat evil replyTo = &database.Message{ - MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, currentMsg.ReplyTo.MessageID, ptr.Val(currentMsg.ReplyTo.PartID)), - Room: currentMsg.ReplyToRoom, - SenderID: currentMsg.ReplyToUser, - } - if currentMsg.ReplyToLogin != "" && (portal.Receiver == "" || portal.Receiver == currentMsg.ReplyToLogin) { - userLogin, err := portal.Bridge.GetExistingUserLoginByID(ctx, currentMsg.ReplyToLogin) - if err != nil { - log.Err(err). - Str("reply_to_login", string(currentMsg.ReplyToLogin)). - Msg("Failed to get reply target user login") - } else if userLogin != nil { - replyTo.SenderMXID = userLogin.UserMXID - } - } else { - ghost, err := portal.Bridge.GetGhostByID(ctx, currentMsg.ReplyToUser) - if err != nil { - log.Err(err). - Str("reply_to_user_id", string(currentMsg.ReplyToUser)). - Msg("Failed to get reply target ghost") - } else { - replyTo.SenderMXID = ghost.Intent.GetMXID() - } + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, replyToPtr.MessageID, ptr.Val(replyToPtr.PartID)), } } else { - log.Warn().Any("reply_to", *currentMsg.ReplyTo).Msg("Reply target message not found in database") + log.Warn().Any("reply_to", *replyToPtr).Msg("Reply target message not found in database") } } } - if currentMsg.ThreadRoot != nil && *currentMsg.ThreadRoot != currentMsgID { - threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot) + if threadRootPtr != nil && *threadRootPtr != currentMsg { + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *threadRootPtr) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") } else if threadRoot == nil { - if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { + if isBatchSend { threadRoot = &database.Message{ - MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *currentMsg.ThreadRoot, ""), + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *threadRootPtr, ""), } } else { - log.Warn().Str("thread_root", string(*currentMsg.ThreadRoot)).Msg("Thread root message not found in database") + log.Warn().Str("thread_root", string(*threadRootPtr)).Msg("Thread root message not found in database") } - } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot); err != nil { + } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *threadRootPtr); err != nil { log.Err(err).Msg("Failed to get last thread message from database") } if prevThreadEvent == nil { - prevThreadEvent = ptr.Clone(threadRoot) + prevThreadEvent = threadRoot } } return } -func (portal *Portal) applyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { +func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { if content.Mentions == nil { content.Mentions = &event.Mentions{} } @@ -2806,24 +1616,7 @@ func (portal *Portal) applyRelationMeta(ctx context.Context, content *event.Mess content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) } if replyTo != nil { - crossRoom := !replyTo.Room.IsEmpty() && replyTo.Room != portal.PortalKey - if !crossRoom || portal.Bridge.Config.CrossRoomReplies { - content.GetRelatesTo().SetReplyTo(replyTo.MXID) - } - if crossRoom && portal.Bridge.Config.CrossRoomReplies { - targetPortal, err := portal.Bridge.GetExistingPortalByKey(ctx, replyTo.Room) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Object("target_portal_key", replyTo.Room). - Msg("Failed to get cross-room reply portal") - } else if targetPortal == nil || targetPortal.MXID == "" { - zerolog.Ctx(ctx).Warn(). - Object("target_portal_key", replyTo.Room). - Msg("Cross-room reply portal not found") - } else { - content.RelatesTo.InReplyTo.UnstableRoomID = targetPortal.MXID - } - } + content.GetRelatesTo().SetReplyTo(replyTo.MXID) content.Mentions.Add(replyTo.SenderMXID) } } @@ -2837,32 +1630,27 @@ func (portal *Portal) sendConvertedMessage( ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event, -) ([]*database.Message, EventHandlingResult) { +) []*database.Message { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e } } log := zerolog.Ctx(ctx) - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( - ctx, id, converted, false, - ) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, id, converted.ReplyTo, converted.ThreadRoot, false) output := make([]*database.Message, 0, len(converted.Parts)) - allSuccess := true for i, part := range converted.Parts { - portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) - part.Content.BeeperDisappearingTimer = converted.Disappear.ToEventContent() + portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) dbMessage := &database.Message{ - ID: id, - PartID: part.ID, - Room: portal.PortalKey, - SenderID: senderID, - SenderMXID: intent.GetMXID(), - Timestamp: ts, - ThreadRoot: ptr.Val(converted.ThreadRoot), - ReplyTo: ptr.Val(converted.ReplyTo), - Metadata: part.DBMetadata, - IsDoublePuppeted: intent.IsDoublePuppet(), + ID: id, + PartID: part.ID, + Room: portal.PortalKey, + SenderID: senderID, + SenderMXID: intent.GetMXID(), + Timestamp: ts, + ThreadRoot: ptr.Val(converted.ThreadRoot), + ReplyTo: ptr.Val(converted.ReplyTo), + Metadata: part.DBMetadata, } if part.DontBridge { dbMessage.SetFakeMXID() @@ -2882,7 +1670,6 @@ func (portal *Portal) sendConvertedMessage( }) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") - allSuccess = false continue } logContext(log.Debug()). @@ -2894,16 +1681,14 @@ func (portal *Portal) sendConvertedMessage( err := portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") - allSuccess = false } - if converted.Disappear.Type != event.DisappearingTypeNone && !dbMessage.HasFakeMXID() { - if converted.Disappear.Type == event.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { + if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() { + if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) } - portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: dbMessage.MXID, - Timestamp: dbMessage.Timestamp, DisappearingSetting: converted.Disappear, }) } @@ -2912,10 +1697,7 @@ func (portal *Portal) sendConvertedMessage( } output = append(output, dbMessage) } - if !allSuccess { - return output, EventHandlingResultFailed - } - return output, EventHandlingResultSuccess + return output } func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { @@ -2932,8 +1714,6 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage pending, ok := portal.outgoingMessages[txnID] if !ok { return false, nil - } else if pending.ignore { - return true, nil } delete(portal.outgoingMessages, txnID) pending.db.ID = evt.GetID() @@ -2953,9 +1733,6 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage saveMessage, statusErr = pending.handle(evt, pending.db) } if saveMessage { - if portal.Bridge.Config.OutgoingMessageReID { - pending.db.MXID = portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, pending.db.ID, pending.db.PartID) - } // Hack to ensure the ghost row exists // TODO move to better place (like login) portal.Bridge.GetGhostByID(ctx, pending.db.SenderID) @@ -2968,61 +1745,46 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage if statusErr != nil { portal.sendErrorStatus(ctx, pending.evt, statusErr) } else { - portal.sendSuccessStatus(ctx, pending.evt, getStreamOrder(evt), pending.evt.ID) + portal.sendSuccessStatus(ctx, pending.evt) } } zerolog.Ctx(ctx).Debug().Stringer("event_id", pending.evt.ID).Msg("Received remote echo for message") return true, pending.db } -func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { +func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { log := zerolog.Ctx(ctx) - intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) - if !ok { - return + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) + if intent == nil { + return false } res, err := evt.HandleExisting(ctx, portal, intent, existing) if err != nil { log.Err(err).Msg("Failed to handle existing message in upsert event after receiving remote echo") - } else { - handleRes = EventHandlingResultSuccess } if res.SaveParts { for _, part := range existing { err = portal.Bridge.DB.Message.Update(ctx, part) if err != nil { log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database") - handleRes = EventHandlingResultFailed.WithError(err) } } } if len(res.SubEvents) > 0 { for _, subEvt := range res.SubEvents { - subType := subEvt.GetType() - log := portal.Log.With(). - Str("source_id", string(source.ID)). - Str("action", "handle remote subevent"). - Stringer("bridge_evt_type", subType). - Logger() - subRes := portal.handleRemoteEvent(log.WithContext(ctx), source, subType, subEvt) - if !subRes.Success { - handleRes.Success = false - } + portal.handleRemoteEvent(source, subEvt) } } - continueHandling = res.ContinueMessageHandling - return + return res.ContinueMessageHandling } -func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { +func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { log := zerolog.Ctx(ctx) upsertEvt, isUpsert := evt.(RemoteMessageUpsert) isUpsert = isUpsert && evt.GetType() == RemoteEventMessageUpsert if wasPending, dbMessage := portal.checkPendingMessage(ctx, evt); wasPending { - if isUpsert && dbMessage != nil { - res, _ = portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) - } else { - res = EventHandlingResultIgnored + if isUpsert { + portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) } return } @@ -3031,42 +1793,32 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin log.Err(err).Msg("Failed to check if message is a duplicate") } else if len(existing) > 0 { if isUpsert { - var continueHandling bool - res, continueHandling = portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) - if continueHandling { + if portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) { log.Debug().Msg("Upsert handler said to continue message handling normally") } else { - return res + return } } else { log.Debug().Stringer("existing_mxid", existing[0].MXID).Msg("Ignoring duplicate message") - return EventHandlingResultIgnored + return } } - intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) - if !ok { - return EventHandlingResultFailed + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) + if intent == nil { + return } ts := getEventTS(evt) converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { if errors.Is(err, ErrIgnoringRemoteEvent) { log.Debug().Err(err).Msg("Remote message handling was cancelled by convert function") - return EventHandlingResultIgnored } else { log.Err(err).Msg("Failed to convert remote message") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") - return EventHandlingResultFailed.WithError(err) } + return } - _, res = portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) - if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { - err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) - if err != nil { - log.Warn().Err(err).Msg("Failed to send stop typing event after bridging message") - } - } - return + portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { @@ -3089,7 +1841,7 @@ func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAP } } -func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { +func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { log := zerolog.Ctx(ctx) var existing []*database.Message if bundledEvt, ok := evt.(RemoteEventWithBundledParts); ok { @@ -3101,41 +1853,28 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID) if err != nil { log.Err(err).Msg("Failed to get edit target message") - return EventHandlingResultFailed.WithError(err) + return } } if existing == nil { log.Warn().Msg("Edit target message not found") - return EventHandlingResultIgnored + return } - intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) - if !ok { - return EventHandlingResultFailed - } else if intent.GetMXID() != existing[0].SenderMXID { - log.Warn(). - Stringer("edit_sender_mxid", intent.GetMXID()). - Stringer("original_sender_mxid", existing[0].SenderMXID). - Msg("Not bridging edit: sender doesn't match original message sender") - return EventHandlingResultIgnored + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) + if intent == nil { + return } ts := getEventTS(evt) converted, err := evt.ConvertEdit(ctx, portal, intent, existing) if errors.Is(err, ErrIgnoringRemoteEvent) { log.Debug().Err(err).Msg("Remote edit handling was cancelled by convert function") - return EventHandlingResultIgnored + return } else if err != nil { log.Err(err).Msg("Failed to convert remote edit") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") - return EventHandlingResultFailed.WithError(err) + return } - res := portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) - if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { - err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) - if err != nil { - log.Warn().Err(err).Msg("Failed to send stop typing event after bridging edit") - } - } - return res + portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) } func (portal *Portal) sendConvertedEdit( @@ -3146,9 +1885,8 @@ func (portal *Portal) sendConvertedEdit( intent MatrixAPI, ts time.Time, streamOrder int64, -) EventHandlingResult { +) { log := zerolog.Ctx(ctx) - allSuccess := true for i, part := range converted.ModifiedParts { if part.Content.Mentions == nil { part.Content.Mentions = &event.Mentions{} @@ -3184,7 +1922,6 @@ func (portal *Portal) sendConvertedEdit( }) if err != nil { log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") - allSuccess = false continue } else { log.Debug(). @@ -3199,7 +1936,6 @@ func (portal *Portal) sendConvertedEdit( err := portal.Bridge.DB.Message.Update(ctx, part.Part) if err != nil { log.Err(err).Int64("part_rowid", part.Part.RowID).Msg("Failed to update message part in database") - allSuccess = false } } for _, part := range converted.DeletedParts { @@ -3213,7 +1949,6 @@ func (portal *Portal) sendConvertedEdit( }) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") - allSuccess = false } else { log.Debug(). Stringer("redaction_event_id", resp.EventID). @@ -3224,19 +1959,11 @@ func (portal *Portal) sendConvertedEdit( err = portal.Bridge.DB.Message.Delete(ctx, part.RowID) if err != nil { log.Err(err).Int64("part_rowid", part.RowID).Msg("Failed to delete message part from database") - allSuccess = false } } if converted.AddedParts != nil { - _, res := portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) - if !res.Success { - allSuccess = false - } + portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) } - if !allSuccess { - return EventHandlingResultFailed - } - return EventHandlingResultSuccess } func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { @@ -3249,9 +1976,9 @@ func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventW func (portal *Portal) getTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) { if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { - return portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + return portal.Bridge.DB.Reaction.GetByID(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) } else { - return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) } } @@ -3269,27 +1996,23 @@ func getStreamOrder(evt RemoteEvent) int64 { return 0 } -func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { +func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { log := zerolog.Ctx(ctx) eventTS := getEventTS(evt) targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return EventHandlingResultFailed.WithError(err) + return } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") - return EventHandlingResultIgnored + return } var existingReactions []*database.Reaction if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { - existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) } else { - existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, evt.GetTargetMessage()) - } - if err != nil { - log.Err(err).Msg("Failed to get existing reactions for reaction sync") - return EventHandlingResultFailed.WithError(err) + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, evt.GetTargetMessage()) } existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction) for _, existingReaction := range existingReactions { @@ -3299,14 +2022,8 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User existing[existingReaction.SenderID][existingReaction.EmojiID] = existingReaction } - doAddReaction := func(new *BackfillReaction, intent MatrixAPI) { - if intent == nil { - var ok bool - intent, ok = portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) - if !ok { - return - } - } + doAddReaction := func(new *BackfillReaction) MatrixAPI { + intent := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) portal.sendConvertedReaction( ctx, new.Sender.Sender, intent, targetMessage, new.EmojiID, new.Emoji, new.Timestamp, new.DBMetadata, new.ExtraContent, @@ -3316,6 +2033,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User Time("reaction_ts", new.Timestamp) }, ) + return intent } doRemoveReaction := func(old *database.Reaction, intent MatrixAPI, deleteRow bool) { if intent == nil && old.SenderMXID != "" { @@ -3349,12 +2067,8 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } } doOverwriteReaction := func(new *BackfillReaction, old *database.Reaction) { - intent, ok := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) - if !ok { - return - } + intent := doAddReaction(new) doRemoveReaction(old, intent, false) - doAddReaction(new, intent) } newData := evt.GetReactions() @@ -3368,12 +2082,12 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User existingReaction, ok := existingUserReactions[reaction.EmojiID] if ok { delete(existingUserReactions, reaction.EmojiID) - if reaction.EmojiID != "" || reaction.Emoji == existingReaction.Emoji { + if reaction.EmojiID != "" { continue } doOverwriteReaction(reaction, existingReaction) } else { - doAddReaction(reaction, nil) + doAddReaction(reaction) } } totalReactionCount := len(existingUserReactions) + len(reactions.Reactions) @@ -3403,34 +2117,30 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } } } - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { +func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { log := zerolog.Ctx(ctx) targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return EventHandlingResultFailed.WithError(err) + return } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") - return EventHandlingResultIgnored + return } emoji, emojiID := evt.GetReactionEmoji() - existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) + existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") - return EventHandlingResultFailed.WithError(err) + return } else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) { log.Debug().Msg("Ignoring duplicate reaction") - return EventHandlingResultIgnored + return } ts := getEventTS(evt) - intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) - if !ok { - return EventHandlingResultFailed - } + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) var extra map[string]any if extraContentProvider, ok := evt.(RemoteReactionWithExtraContent); ok { extra = extraContentProvider.GetReactionExtraContent() @@ -3439,6 +2149,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { dbMetadata = metaProvider.GetReactionDBMetadata() } + portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) if existingReaction != nil { _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ @@ -3449,14 +2160,13 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi log.Err(err).Msg("Failed to redact old reaction") } } - return portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) } func (portal *Portal) sendConvertedReaction( ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event, -) EventHandlingResult { +) { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e @@ -3491,7 +2201,7 @@ func (portal *Portal) sendConvertedReaction( }) if err != nil { logContext(log.Err(err)).Msg("Failed to send reaction to Matrix") - return EventHandlingResultFailed.WithError(err) + return } logContext(log.Debug()). Stringer("event_id", resp.EventID). @@ -3500,9 +2210,7 @@ func (portal *Portal) sendConvertedReaction( err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { logContext(log.Err(err)).Msg("Failed to save reaction to database") - return EventHandlingResultFailed.WithError(err) } - return EventHandlingResultSuccess } func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { @@ -3521,26 +2229,22 @@ func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (M } } -func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { +func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { log := zerolog.Ctx(ctx) targetReaction, err := portal.getTargetReaction(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target reaction for removal") - return EventHandlingResultFailed.WithError(err) + return } else if targetReaction == nil { log.Warn().Msg("Target reaction not found") - return EventHandlingResultIgnored + return } intent, err := portal.getIntentForMXID(ctx, targetReaction.SenderMXID) if err != nil { log.Err(err).Stringer("sender_mxid", targetReaction.SenderMXID).Msg("Failed to get intent for removing reaction") } if intent == nil { - var ok bool - intent, ok = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) - if !ok { - return EventHandlingResultFailed - } + intent = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) } ts := getEventTS(evt) _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ @@ -3550,42 +2254,24 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us }, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction}) if err != nil { log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") - return EventHandlingResultFailed.WithError(err) } err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction) if err != nil { log.Err(err).Msg("Failed to delete target reaction from database") } - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { +func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { log := zerolog.Ctx(ctx) targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to get target message for removal") - return EventHandlingResultFailed.WithError(err) + return } else if len(targetParts) == 0 { log.Debug().Msg("Target message not found") - return EventHandlingResultIgnored - } - onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) - onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() - if onlyForMe && portal.Receiver == "" { - _, others, err := portal.findOtherLogins(ctx, source) - if err != nil { - log.Err(err).Msg("Failed to check if portal has other logins") - return EventHandlingResultFailed.WithError(err) - } else if len(others) > 0 { - log.Debug().Msg("Ignoring delete for me event in portal with multiple logins") - return EventHandlingResultIgnored - } - } - - intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) - if !ok { - return EventHandlingResultFailed + return } + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) if intent == portal.Bridge.Bot && len(targetParts) > 0 { senderIntent, err := portal.getIntentForMXID(ctx, targetParts[0].SenderMXID) if err != nil { @@ -3594,17 +2280,15 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use intent = senderIntent } } - res := portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) + portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) err = portal.Bridge.DB.Message.DeleteAllParts(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to delete target message from database") } - return res } -func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { +func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { log := zerolog.Ctx(ctx) - var anyFailed bool for _, part := range parts { if part.HasFakeMXID() { continue @@ -3616,7 +2300,6 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. }, &MatrixSendExtra{Timestamp: ts, MessageMeta: part}) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") - anyFailed = true } else { log.Debug(). Stringer("redaction_event_id", resp.EventID). @@ -3625,33 +2308,22 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. Msg("Sent redaction of message part to Matrix") } } - if anyFailed { - return EventHandlingResultFailed - } - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { +func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { + // TODO exclude fake mxids log := zerolog.Ctx(ctx) var err error var lastTarget *database.Message - readUpTo := evt.GetReadUpTo() if lastTargetID := evt.GetLastReceiptTarget(); lastTargetID != "" { lastTarget, err = portal.Bridge.DB.Message.GetLastPartByID(ctx, portal.Receiver, lastTargetID) if err != nil { log.Err(err).Str("last_target_id", string(lastTargetID)). Msg("Failed to get last target message for read receipt") - return EventHandlingResultFailed.WithError(err) + return } else if lastTarget == nil { log.Debug().Str("last_target_id", string(lastTargetID)). Msg("Last target message not found") - } else if lastTarget.HasFakeMXID() { - log.Debug().Str("last_target_id", string(lastTargetID)). - Msg("Last target message is fake") - if readUpTo.IsZero() { - readUpTo = lastTarget.Timestamp - } - lastTarget = nil } } if lastTarget == nil { @@ -3660,89 +2332,62 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL if err != nil { log.Err(err).Str("target_id", string(targetID)). Msg("Failed to get target message for read receipt") - return EventHandlingResultFailed.WithError(err) - } else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { + return + } else if target != nil && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { lastTarget = target } } } + readUpTo := evt.GetReadUpTo() if lastTarget == nil && !readUpTo.IsZero() { - lastTarget, err = portal.Bridge.DB.Message.GetLastNonFakePartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) + lastTarget, err = portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) if err != nil { log.Err(err).Time("read_up_to", readUpTo).Msg("Failed to get target message for read receipt") } } - sender := evt.GetSender() - intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) - if !ok { - return EventHandlingResultFailed - } - var addTargetLog func(evt *zerolog.Event) *zerolog.Event if lastTarget == nil { - sevt, evtOK := evt.(RemoteReadReceiptWithStreamOrder) - soIntent, soIntentOK := intent.(StreamOrderReadingMatrixAPI) - if !evtOK || !soIntentOK || sevt.GetReadUpToStreamOrder() == 0 { - log.Warn().Msg("No target message found for read receipt") - return EventHandlingResultIgnored - } - targetStreamOrder := sevt.GetReadUpToStreamOrder() - addTargetLog = func(evt *zerolog.Event) *zerolog.Event { - return evt.Int64("target_stream_order", targetStreamOrder) - } - err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt)) - if readUpTo.IsZero() { - readUpTo = getEventTS(evt) - } - } else { - addTargetLog = func(evt *zerolog.Event) *zerolog.Event { - return evt.Stringer("target_mxid", lastTarget.MXID) - } - err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) - readUpTo = lastTarget.Timestamp + log.Warn().Msg("No target message found for read receipt") + return } + sender := evt.GetSender() + intent := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) + err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) if err != nil { - addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") - return EventHandlingResultFailed.WithError(err) + log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") } else { - addTargetLog(log.Debug()).Msg("Bridged read receipt") + log.Debug().Stringer("target_mxid", lastTarget.MXID).Msg("Bridged read receipt") } if sender.IsFromMe { - portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo) + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { +func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { if !evt.GetSender().IsFromMe { zerolog.Ctx(ctx).Warn().Msg("Ignoring mark unread event from non-self user") - return EventHandlingResultIgnored + return } dp := source.User.DoublePuppet(ctx) if dp == nil { - return EventHandlingResultIgnored + return } err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event") - return EventHandlingResultFailed.WithError(err) } - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { - if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") { - return EventHandlingResultIgnored - } - intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) - if !ok { - return EventHandlingResultFailed +func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { + if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID { + return } + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) log := zerolog.Ctx(ctx) for _, target := range evt.GetReceiptTargets() { targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target) if err != nil { log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt") - return EventHandlingResultFailed.WithError(err) + continue } else if len(targetParts) == 0 { continue } else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost { @@ -3753,51 +2398,36 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U Status: event.MessageStatusSuccess, DeliveredTo: []id.UserID{intent.GetMXID()}, }, &MessageStatusEventInfo{ - RoomID: portal.MXID, - SourceEventID: part.MXID, - Sender: part.SenderMXID, - - IsSourceEventDoublePuppeted: part.IsDoublePuppeted, + RoomID: portal.MXID, + EventID: part.MXID, + Sender: part.SenderMXID, }) } } - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { +func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { var typingType TypingType if typedEvt, ok := evt.(RemoteTypingWithType); ok { typingType = typedEvt.GetTypingType() } - intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) - if !ok { - return EventHandlingResultFailed - } - timeout := evt.GetTimeout() - err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) + err := intent.MarkTyping(ctx, portal.MXID, typingType, evt.GetTimeout()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") - return EventHandlingResultFailed.WithError(err) } - if timeout == 0 { - portal.currentlyTypingGhosts.Remove(intent.GetMXID()) - } else { - portal.currentlyTypingGhosts.Add(intent.GetMXID()) - } - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { +func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { info, err := evt.GetChatInfoChange(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change") - return EventHandlingResultFailed.WithError(err) + return } portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt)) - return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { +func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { log := zerolog.Ctx(ctx) infoProvider, ok := evt.(RemoteChatResyncWithInfo) if ok { @@ -3806,12 +2436,10 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo log.Err(err).Msg("Failed to get chat info from resync event") } else if info != nil { portal.UpdateInfo(ctx, info, source, nil, time.Time{}) - } else { - log.Debug().Msg("No chat info provided in resync event") } } backfillChecker, ok := evt.(RemoteChatResyncBackfill) - if portal.Bridge.Config.Backfill.Enabled && ok && portal.RoomType != database.RoomTypeSpace { + if portal.Bridge.Config.Backfill.Enabled && ok { latestMessage, err := portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, time.Now().Add(10*time.Second)) if err != nil { log.Err(err).Msg("Failed to get last message in portal to check if backfill is necessary") @@ -3826,120 +2454,29 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo portal.doForwardBackfill(ctx, source, latestMessage, bundle) } } - return EventHandlingResultSuccess } -func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { - others, err = portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) - if err != nil { - return - } - others = slices.DeleteFunc(others, func(up *database.UserPortal) bool { - if up.LoginID == source.ID { - ownUP = up - return true - } - return false - }) - return -} - -type childDeleteProxy struct { - RemoteChatDeleteWithChildren - child networkid.PortalKey - done func() -} - -func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context { - return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children") -} -func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child } -func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false } -func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {} -func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() } - -func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { - log := zerolog.Ctx(ctx) +func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { if portal.Receiver == "" && evt.DeleteOnlyForMe() { - ownUP, logins, err := portal.findOtherLogins(ctx, source) - if err != nil { - log.Err(err).Msg("Failed to check if portal has other logins") - return EventHandlingResultFailed.WithError(err) - } - if len(logins) > 0 { - log.Debug().Msg("Not deleting portal with other logins in remote chat delete event") - if ownUP != nil { - err = portal.Bridge.DB.UserPortal.Delete(ctx, ownUP) - if err != nil { - log.Err(err).Msg("Failed to delete own user portal row from database") - } else { - log.Debug().Msg("Deleted own user portal row from database") - } - } - _, err = portal.sendStateWithIntentOrBot( - ctx, - source.User.DoublePuppet(ctx), - event.StateMember, - source.UserMXID.String(), - &event.Content{Parsed: &event.MemberEventContent{Membership: event.MembershipLeave}}, - getEventTS(evt), - ) - if err != nil { - log.Err(err).Msg("Failed to send leave state event for user after remote chat delete") - return EventHandlingResultFailed.WithError(err) - } else { - log.Debug().Msg("Sent leave state event for user after remote chat delete") - return EventHandlingResultSuccess - } - } - } - if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace { - children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to fetch children to delete") - return EventHandlingResultFailed.WithError(err) - } - log.Debug(). - Int("portal_count", len(children)). - Msg("Deleting child portals before remote chat delete") - var wg sync.WaitGroup - wg.Add(len(children)) - for _, child := range children { - child.queueEvent(ctx, &portalRemoteEvent{ - evt: &childDeleteProxy{ - RemoteChatDeleteWithChildren: childDeleter, - child: child.PortalKey, - done: wg.Done, - }, - source: source, - evtType: RemoteEventChatDelete, - }) - } - wg.Wait() - log.Debug().Msg("Finished deleting child portals") + // TODO check if there are other users } err := portal.Delete(ctx) if err != nil { - log.Err(err).Msg("Failed to delete portal from database") - return EventHandlingResultFailed.WithError(err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete portal from database") + return } err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) if err != nil { - log.Err(err).Msg("Failed to delete Matrix room") - return EventHandlingResultFailed.WithError(err) - } else { - log.Info().Msg("Deleted room after remote chat delete event") - return EventHandlingResultSuccess + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete Matrix room") } } -func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { +func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { //data, err := backfill.GetBackfillData(ctx, portal) //if err != nil { // zerolog.Ctx(ctx).Err(err).Msg("Failed to get backfill data") // return //} - return } type ChatInfoChange struct { @@ -3952,10 +2489,7 @@ type ChatInfoChange struct { } func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSender, source *UserLogin, change *ChatInfoChange, ts time.Time) { - intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) - if !ok { - return - } + intent := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) if change.ChatInfo != nil { portal.UpdateInfo(ctx, change.ChatInfo, source, intent, ts) } @@ -3973,45 +2507,13 @@ type PortalInfo = ChatInfo type ChatMember struct { EventSender Membership event.Membership - // Per-room nickname for the user. Not yet used. - Nickname *string - // The power level to set for the user when syncing power levels. + Nickname *string PowerLevel *int - // Optional user info to sync the ghost user while updating membership. - UserInfo *UserInfo - // The user who sent the membership change (user who invited/kicked/banned this user). - // Not yet used. Not applicable if Membership is join or knock. - MemberSender EventSender - // Extra fields to include in the member event. - MemberEventExtra map[string]any - // The expected previous membership. If this doesn't match, the change is ignored. + UserInfo *UserInfo + PrevMembership event.Membership } -type ChatMemberMap map[networkid.UserID]ChatMember - -// Set adds the given entry to this map, overwriting any existing entry with the same Sender field. -func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap { - if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { - return cmm - } - cmm[member.Sender] = member - return cmm -} - -// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists. -// It returns true if the entry was added, false otherwise. -func (cmm ChatMemberMap) Add(member ChatMember) bool { - if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { - return false - } - if _, exists := cmm[member.Sender]; exists { - return false - } - cmm[member.Sender] = member - return true -} - type ChatMemberList struct { // Whether this is the full member list. // If true, any extra members not listed here will be removed from the portal. @@ -4019,37 +2521,18 @@ type ChatMemberList struct { // Should the bridge call IsThisUser for every member in the list? // This should be used when SenderLogin can't be filled accurately. CheckAllLogins bool - // Should any changes have the `com.beeper.exclude_from_timeline` flag set by default? - // This is recommended for syncs with non-real-time changes. - // Real-time changes (e.g. a user joining) should not set this flag set. - ExcludeChangesFromTimeline bool - // The total number of members in the chat, regardless of how many of those members are included in MemberMap. + // The total number of members in the chat, regardless of how many of those members are included in Members. TotalMemberCount int // For DM portals, the ID of the recipient user. - // This field is optional and will be automatically filled from MemberMap if there are only 2 entries in the map. + // This field is optional and will be automatically filled from Members if there are only 2 entries in the list. OtherUserID networkid.UserID - // Deprecated: Use MemberMap instead to avoid duplicate entries Members []ChatMember - MemberMap ChatMemberMap PowerLevels *PowerLevelOverrides } -func (cml *ChatMemberList) memberListToMap(ctx context.Context) { - if cml.Members == nil || cml.MemberMap != nil { - return - } - cml.MemberMap = make(map[networkid.UserID]ChatMember, len(cml.Members)) - for _, member := range cml.Members { - if _, alreadyExists := cml.MemberMap[member.Sender]; alreadyExists { - zerolog.Ctx(ctx).Warn().Str("member_id", string(member.Sender)).Msg("Duplicate member in list") - } - cml.MemberMap[member.Sender] = member - } -} - type PowerLevelOverrides struct { Events map[event.Type]int UsersDefault *int @@ -4066,10 +2549,8 @@ type PowerLevelOverrides struct { // Deprecated: renamed to PowerLevelOverrides type PowerLevelChanges = PowerLevelOverrides -func allowChange(newLevel *int, oldLevel, actorLevel int) bool { - return newLevel != nil && - *newLevel <= actorLevel && oldLevel <= actorLevel && - oldLevel != *newLevel +func allowChange(newLevel, oldLevel, actorLevel int) bool { + return newLevel <= actorLevel && oldLevel <= actorLevel } func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevelsEventContent) (changed bool) { @@ -4085,32 +2566,32 @@ func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevel } else { actorLevel = (1 << 31) - 1 } - if allowChange(plc.UsersDefault, content.UsersDefault, actorLevel) { - changed = true + if plc.UsersDefault != nil && allowChange(*plc.UsersDefault, content.UsersDefault, actorLevel) { + changed = content.UsersDefault != *plc.UsersDefault content.UsersDefault = *plc.UsersDefault } - if allowChange(plc.EventsDefault, content.EventsDefault, actorLevel) { - changed = true + if plc.EventsDefault != nil && allowChange(*plc.EventsDefault, content.EventsDefault, actorLevel) { + changed = content.EventsDefault != *plc.EventsDefault content.EventsDefault = *plc.EventsDefault } - if allowChange(plc.StateDefault, content.StateDefault(), actorLevel) { - changed = true + if plc.StateDefault != nil && allowChange(*plc.StateDefault, content.StateDefault(), actorLevel) { + changed = content.StateDefault() != *plc.StateDefault content.StateDefaultPtr = plc.StateDefault } - if allowChange(plc.Invite, content.Invite(), actorLevel) { - changed = true + if plc.Invite != nil && allowChange(*plc.Invite, content.Invite(), actorLevel) { + changed = content.Invite() != *plc.Invite content.InvitePtr = plc.Invite } - if allowChange(plc.Kick, content.Kick(), actorLevel) { - changed = true + if plc.Kick != nil && allowChange(*plc.Kick, content.Kick(), actorLevel) { + changed = content.Kick() != *plc.Kick content.KickPtr = plc.Kick } - if allowChange(plc.Ban, content.Ban(), actorLevel) { - changed = true + if plc.Ban != nil && allowChange(*plc.Ban, content.Ban(), actorLevel) { + changed = content.Ban() != *plc.Ban content.BanPtr = plc.Ban } - if allowChange(plc.Redact, content.Redact(), actorLevel) { - changed = true + if plc.Redact != nil && allowChange(*plc.Redact, content.Redact(), actorLevel) { + changed = content.Redact() != *plc.Redact content.RedactPtr = plc.Redact } if plc.Custom != nil { @@ -4119,10 +2600,6 @@ func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevel return changed } -// DefaultChatName can be used to explicitly clear the name of a room -// and reset it to the default one based on members. -var DefaultChatName = ptr.Ptr("") - type ChatInfo struct { Name *string Topic *string @@ -4135,11 +2612,9 @@ type ChatInfo struct { Disappear *database.DisappearingSetting ParentID *networkid.PortalID - UserLocal *UserLocalPortalInfo - MessageRequest *bool - CanBackfill bool + UserLocal *UserLocalPortalInfo - ExcludeChangesFromTimeline bool + CanBackfill bool ExtraUpdates ExtraUpdater[*Portal] } @@ -4173,36 +2648,26 @@ type UserLocalPortalInfo struct { Tag *event.RoomTag } -func (portal *Portal) updateName( - ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, -) bool { +func (portal *Portal) updateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { if portal.Name == name && (portal.NameSet || portal.MXID == "") { return false } portal.Name = name - portal.NameSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil, - ) + portal.NameSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}) return true } -func (portal *Portal) updateTopic( - ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, -) bool { +func (portal *Portal) updateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { return false } portal.Topic = topic - portal.TopicSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil, - ) + portal.TopicSet = portal.sendRoomMeta(ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}) return true } -func (portal *Portal) updateAvatar( - ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, -) bool { - if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") { +func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { + if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") { return false } portal.AvatarID = avatar.ID @@ -4218,15 +2683,13 @@ func (portal *Portal) updateAvatar( portal.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") return true - } else if newHash == portal.AvatarHash && portal.AvatarMXC != "" && portal.AvatarSet { + } else if newHash == portal.AvatarHash && portal.AvatarSet { return true } portal.AvatarMXC = newMXC portal.AvatarHash = newHash } - portal.AvatarSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil, - ) + portal.AvatarSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}) return true } @@ -4240,28 +2703,15 @@ func (portal *Portal) GetTopLevelParent() *Portal { return portal.Parent.GetTopLevelParent() } -func (portal *Portal) getBridgeInfoStateKey() string { - if portal.Bridge.Config.NoBridgeInfoStateKey { - return "" - } - idProvider, ok := portal.Bridge.Matrix.(MatrixConnectorWithBridgeIdentifier) - if ok { - return idProvider.GetUniqueBridgeID() - } - return string(portal.BridgeID) -} - func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { bridgeInfo := event.BridgeEventContent{ BridgeBot: portal.Bridge.Bot.GetMXID(), Creator: portal.Bridge.Bot.GetMXID(), Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(), Channel: event.BridgeInfoSection{ - ID: string(portal.ID), - DisplayName: portal.Name, - AvatarURL: portal.AvatarMXC, - Receiver: string(portal.Receiver), - MessageRequest: portal.MessageRequest, + ID: string(portal.ID), + DisplayName: portal.Name, + AvatarURL: portal.AvatarMXC, // TODO external URL? }, BeeperRoomTypeV2: string(portal.RoomType), @@ -4269,10 +2719,6 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM { bridgeInfo.BeeperRoomType = "dm" } - if bridgeInfo.Protocol.ID == "slackgo" { - bridgeInfo.TempSlackRemoteIDMigratedFlag = true - bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true - } parent := portal.GetTopLevelParent() if parent != nil { bridgeInfo.Network = &event.BridgeInfoSection{ @@ -4286,7 +2732,10 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { if ok { filler.FillPortalBridgeInfo(portal, &bridgeInfo) } - return portal.getBridgeInfoStateKey(), bridgeInfo + // TODO use something globally unique instead of bridge ID? + // maybe ask the matrix connector to use serverName+appserviceID+bridgeID + stateKey := string(portal.BridgeID) + return stateKey, bridgeInfo } func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { @@ -4294,54 +2743,8 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { return } stateKey, bridgeInfo := portal.getBridgeInfo() - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil) - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil) -} - -func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { - if portal.MXID == "" { - return false - } else if !implicit && time.Since(portal.lastCapUpdate) < 24*time.Hour { - return false - } else if portal.CapState.ID != "" && source.ID != portal.CapState.Source && source.ID != portal.Receiver { - // TODO allow capability state source to change if the old user login is removed from the portal - return false - } - caps := source.Client.GetCapabilities(ctx, portal) - capID := caps.GetID() - if capID == portal.CapState.ID { - return false - } - zerolog.Ctx(ctx).Debug(). - Str("user_login_id", string(source.ID)). - Str("old_id", portal.CapState.ID). - Str("new_id", capID). - Msg("Sending new room capability event") - success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil) - if !success { - return false - } - portal.CapState = database.CapabilityState{ - Source: source.ID, - ID: capID, - Flags: portal.CapState.Flags, - } - if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { - zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") - success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) - if !success { - return false - } - portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet - } - portal.lastCapUpdate = time.Now() - if implicit { - err := portal.Save(ctx) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal capability state after sending state event") - } - } - return true + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) } func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender MatrixAPI, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { @@ -4359,27 +2762,15 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri return } -func (portal *Portal) sendRoomMeta( - ctx context.Context, - sender MatrixAPI, - ts time.Time, - eventType event.Type, - stateKey string, - content any, - excludeFromTimeline bool, - extra map[string]any, -) bool { +func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { if portal.MXID == "" { return false } - if extra == nil { - extra = make(map[string]any) - } - if excludeFromTimeline { - extra["com.beeper.exclude_from_timeline"] = true - } + var extra map[string]any if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) { - extra["fi.mau.implicit_name"] = true + extra = map[string]any{ + "fi.mau.implicit_name": true, + } } _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{ Parsed: content, @@ -4391,62 +2782,16 @@ func (portal *Portal) sendRoomMeta( Msg("Failed to set room metadata") return false } - if eventType == event.StateBeeperDisappearingTimer { - // TODO remove this debug log at some point - zerolog.Ctx(ctx).Debug(). - Any("content", content). - Msg("Sent new disappearing timer event") - } return true } -func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) { - if !portal.Bridge.Config.RevertFailedStateChanges { - return - } - if evt.GetStateKey() != "" && evt.Type != event.StateMember { - return - } - switch evt.Type { - case event.StateRoomName: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil) - case event.StateRoomAvatar: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil) - case event.StateTopic: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil) - case event.StateBeeperDisappearingTimer: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) - case event.StateMember: - var prevContent *event.MemberEventContent - var extra map[string]any - if evt.Unsigned.PrevContent != nil { - _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - prevContent = evt.Unsigned.PrevContent.AsMember() - newContent := evt.Content.AsMember() - if prevContent.Membership == newContent.Membership { - return - } - extra = evt.Unsigned.PrevContent.Raw - } else { - prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} - } - if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange { - if extra == nil { - extra = make(map[string]any) - } - extra["com.beeper.member_rollback"] = true - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra) - } - } -} - func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { if members == nil { invite = []id.UserID{source.UserMXID} return } var loginsInPortal []*UserLogin - if members.CheckAllLogins && !portal.Bridge.Config.SplitPortals { + if members.CheckAllLogins { loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { err = fmt.Errorf("failed to get user logins in portal: %w", err) @@ -4454,12 +2799,7 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem } } members.PowerLevels.Apply("", pl) - members.memberListToMap(ctx) - for _, member := range members.MemberMap { - if ctx.Err() != nil { - err = ctx.Err() - return - } + for _, member := range members.Members { if member.Membership != event.MembershipJoin && member.Membership != "" { continue } @@ -4471,10 +2811,7 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem ghost.UpdateInfo(ctx, member.UserInfo) } } - intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) - if err != nil { - return nil, nil, err - } + intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) if extraUserID != "" { invite = append(invite, extraUserID) if member.PowerLevel != nil { @@ -4498,18 +2835,16 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem } func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberList) (changed bool) { - members.memberListToMap(ctx) var expectedUserID networkid.UserID if portal.RoomType != database.RoomTypeDM { // expected user ID is empty } else if members.OtherUserID != "" { expectedUserID = members.OtherUserID - } else if len(members.MemberMap) == 2 && members.IsFull { - vals := maps.Values(members.MemberMap) - if vals[0].IsFromMe && !vals[1].IsFromMe { - expectedUserID = vals[1].Sender - } else if vals[1].IsFromMe && !vals[0].IsFromMe { - expectedUserID = vals[0].Sender + } else if len(members.Members) == 2 && members.IsFull { + if members.Members[0].IsFromMe && !members.Members[1].IsFromMe { + expectedUserID = members.Members[1].Sender + } else if members.Members[1].IsFromMe && !members.Members[0].IsFromMe { + expectedUserID = members.Members[0].Sender } } if portal.OtherUserID != expectedUserID { @@ -4523,50 +2858,10 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } -func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool { - switch rule.JoinRule { - case event.JoinRulePublic: - return true - case event.JoinRuleKnockRestricted, event.JoinRuleRestricted: - for _, allow := range rule.Allow { - if allow.Type == "fi.mau.spam_checker" { - return true - } - } - } - return false -} - -func (portal *Portal) roomIsPublic(ctx context.Context) bool { - mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) - if !ok { - return false - } - evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") - if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") - return false - } else if evt == nil { - return false - } - content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent) - if !ok { - return false - } - return looksDirectlyJoinable(content) -} - -func (portal *Portal) syncParticipants( - ctx context.Context, - members *ChatMemberList, - source *UserLogin, - sender MatrixAPI, - ts time.Time, -) error { - members.memberListToMap(ctx) +func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { var loginsInPortal []*UserLogin var err error - if members.CheckAllLogins && !portal.Bridge.Config.SplitPortals { + if members.CheckAllLogins { loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { return fmt.Errorf("failed to get user logins in portal: %w", err) @@ -4586,13 +2881,7 @@ func (portal *Portal) syncParticipants( } delete(currentMembers, portal.Bridge.Bot.GetMXID()) powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower) - addExcludeFromTimeline := func(raw map[string]any) { - _, hasKey := raw["com.beeper.exclude_from_timeline"] - if !hasKey && members.ExcludeChangesFromTimeline { - raw["com.beeper.exclude_from_timeline"] = true - } - } - syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool { + syncUser := func(extraUserID id.UserID, member ChatMember, hasIntent bool) bool { if member.Membership == "" { member.Membership = event.MembershipJoin } @@ -4621,74 +2910,58 @@ func (portal *Portal) syncParticipants( Displayname: currentMember.Displayname, AvatarURL: currentMember.AvatarURL, } - wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} - addExcludeFromTimeline(wrappedContent.Raw) + wrappedContent := &event.Content{Parsed: content, Raw: make(map[string]any)} thisEvtSender := sender - if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) { + if member.Membership == event.MembershipJoin { content.Membership = event.MembershipInvite - if intent != nil { + if hasIntent { wrappedContent.Raw["fi.mau.will_auto_accept"] = true } if thisEvtSender.GetMXID() == extraUserID { thisEvtSender = portal.Bridge.Bot } } - addLogContext := func(e *zerolog.Event) *zerolog.Event { - return e.Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)) - } if currentMember != nil && currentMember.Membership == event.MembershipBan && member.Membership != event.MembershipLeave { unbanContent := *content unbanContent.Membership = event.MembershipLeave wrappedUnbanContent := &event.Content{Parsed: &unbanContent} _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedUnbanContent, ts) if err != nil { - addLogContext(log.Err(err)). - Str("new_membership", string(unbanContent.Membership)). + log.Err(err). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). Msg("Failed to unban user to update membership") } else { - addLogContext(log.Trace()). - Str("new_membership", string(unbanContent.Membership)). + log.Trace(). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). Msg("Unbanned user to update membership") - currentMember.Membership = event.MembershipLeave } } - if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID { - _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts) - } else { - _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) - } + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) if err != nil { - addLogContext(log.Err(err)). - Str("new_membership", string(content.Membership)). + log.Err(err). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). Msg("Failed to update user membership") } else { - addLogContext(log.Trace()). - Str("new_membership", string(content.Membership)). - Msg("Updated membership in room") - currentMember.Membership = content.Membership - - if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin { - content.Membership = event.MembershipJoin - wrappedJoinContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} - addExcludeFromTimeline(wrappedContent.Raw) - _, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts) - if err != nil { - addLogContext(log.Err(err)). - Str("new_membership", string(content.Membership)). - Msg("Failed to join with intent") - } else { - addLogContext(log.Trace()). - Str("new_membership", string(content.Membership)). - Msg("Joined room with intent") - } - } + log.Trace(). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). + Msg("Updating membership in room") } return true } syncIntent := func(intent MatrixAPI, member ChatMember) { - if !syncUser(intent.GetMXID(), member, intent) { + if !syncUser(intent.GetMXID(), member, true) { return } if member.Membership == event.MembershipJoin || member.Membership == "" { @@ -4700,10 +2973,7 @@ func (portal *Portal) syncParticipants( } } } - for _, member := range members.MemberMap { - if ctx.Err() != nil { - return ctx.Err() - } + for _, member := range members.Members { if member.Sender != "" && member.UserInfo != nil { ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) if err != nil { @@ -4712,15 +2982,12 @@ func (portal *Portal) syncParticipants( ghost.UpdateInfo(ctx, member.UserInfo) } } - intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) - if err != nil { - return err - } + intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) if intent != nil { syncIntent(intent, member) } if extraUserID != "" { - syncUser(extraUserID, member, nil) + syncUser(extraUserID, member, false) } } if powerChanged { @@ -4735,7 +3002,7 @@ func (portal *Portal) syncParticipants( if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { continue } - if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) { + if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil { continue } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ @@ -4745,9 +3012,6 @@ func (portal *Portal) syncParticipants( Displayname: memberEvt.Displayname, Reason: "User is not in remote chat", }, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": members.ExcludeChangesFromTimeline, - }, }, time.Now()) if err != nil { zerolog.Ctx(ctx).Err(err). @@ -4779,17 +3043,13 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo if info == nil { return } - if info.MutedUntil != nil && (didJustCreate || !portal.Bridge.Config.MuteOnlyOnCreate) && (!didJustCreate || info.MutedUntil.After(time.Now())) { + if info.MutedUntil != nil && (didJustCreate || !portal.Bridge.Config.MuteOnlyOnCreate) { err := dp.MuteRoom(ctx, portal.MXID, *info.MutedUntil) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to mute room") } } - if info.Tag != nil && - len(portal.Bridge.Config.OnlyBridgeTags) > 0 && - (*info.Tag == "" || slices.Contains(portal.Bridge.Config.OnlyBridgeTags, *info.Tag)) && - (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) && - (!didJustCreate || *info.Tag != "") { + if info.Tag != nil && (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) { err := dp.TagRoom(ctx, portal.MXID, *info.Tag, *info.Tag != "") if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to tag room") @@ -4797,105 +3057,49 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo } } -func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.MessageEventContent { - formattedDuration := exfmt.DurationCustom(expiration, nil, exfmt.Day, time.Hour, time.Minute, time.Second) - content := &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), - Mentions: &event.Mentions{}, +func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool { + if setting.Timer == 0 { + setting.Type = "" } - if expiration == 0 { - if implicit { - content.Body = "Automatically turned off disappearing messages because incoming message is not disappearing" - } else { - content.Body = "Turned off disappearing messages" - } - } else if implicit { - content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", formattedDuration) - } - return content -} - -type UpdateDisappearingSettingOpts struct { - Sender MatrixAPI - Timestamp time.Time - Implicit bool - Save bool - SendNotice bool - - ExcludeFromTimeline bool -} - -func (portal *Portal) UpdateDisappearingSetting( - ctx context.Context, - setting database.DisappearingSetting, - opts UpdateDisappearingSettingOpts, -) bool { - setting = setting.Normalize() if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { return false } portal.Disappear.Type = setting.Type portal.Disappear.Timer = setting.Timer - if opts.Save { + if save { err := portal.Save(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") } } - if portal.MXID == "" { - return true + content := &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("Disappearing messages set to %s", exfmt.Duration(setting.Timer)), } - - if opts.Sender == nil { - opts.Sender = portal.Bridge.Bot + if implicit { + content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", exfmt.Duration(setting.Timer)) + } else if setting.Timer == 0 { + content.Body = "Disappearing messages disabled" } - if opts.Timestamp.IsZero() { - opts.Timestamp = time.Now() + if sender == nil { + sender = portal.Bridge.Bot } - portal.sendRoomMeta( - ctx, - opts.Sender, - opts.Timestamp, - event.StateBeeperDisappearingTimer, - "", - setting.ToEventContent(), - opts.ExcludeFromTimeline, - nil, - ) - - if !opts.SendNotice { - return true - } - content := DisappearingMessageNotice(setting.Timer, opts.Implicit) - _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, - Raw: map[string]any{ - "com.beeper.action_message": map[string]any{ - "type": "disappearing_timer", - "timer": setting.Timer.Milliseconds(), - "timer_type": setting.Type, - "implicit": opts.Implicit, - }, - }, - }, &MatrixSendExtra{Timestamp: opts.Timestamp}) + }, &MatrixSendExtra{Timestamp: ts}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") } else { zerolog.Ctx(ctx).Debug(). Dur("new_timer", portal.Disappear.Timer). - Bool("implicit", opts.Implicit). + Bool("implicit", implicit). Msg("Sent disappearing messages notice") } return true } -func (portal *Portal) updateParent(ctx context.Context, newParentID networkid.PortalID, source *UserLogin) bool { - newParent := networkid.PortalKey{ID: newParentID} - if portal.Bridge.Config.SplitPortals { - newParent.Receiver = portal.Receiver - } - if portal.ParentKey == newParent { +func (portal *Portal) updateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { + if portal.ParentID == newParent { return false } var err error @@ -4905,10 +3109,10 @@ func (portal *Portal) updateParent(ctx context.Context, newParentID networkid.Po zerolog.Ctx(ctx).Err(err).Stringer("old_space_mxid", portal.Parent.MXID).Msg("Failed to remove portal from old space") } } - portal.ParentKey = newParent + portal.ParentID = newParent portal.InSpace = false - if newParent.ID != "" { - portal.Parent, err = portal.Bridge.GetPortalByKey(ctx, newParent) + if newParent != "" { + portal.Parent, err = portal.Bridge.GetPortalByKey(ctx, networkid.PortalKey{ID: newParent}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get new parent portal") } @@ -4946,51 +3150,38 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch return } } - changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}, false) || changed + changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}) || changed changed = portal.updateAvatar(ctx, &Avatar{ ID: ghost.AvatarID, MXC: ghost.AvatarMXC, Hash: ghost.AvatarHash, Remove: ghost.AvatarID == "", - }, nil, time.Time{}, false) || changed + }, nil, time.Time{}) || changed return } func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { changed := false - if info.Name == DefaultChatName { - if portal.NameIsCustom { - portal.NameIsCustom = false - changed = portal.updateName(ctx, "", sender, ts, info.ExcludeChangesFromTimeline) || changed - } - } else if info.Name != nil { + if info.Name != nil { portal.NameIsCustom = true - changed = portal.updateName(ctx, *info.Name, sender, ts, info.ExcludeChangesFromTimeline) || changed + changed = portal.updateName(ctx, *info.Name, sender, ts) || changed } if info.Topic != nil { - changed = portal.updateTopic(ctx, *info.Topic, sender, ts, info.ExcludeChangesFromTimeline) || changed + changed = portal.updateTopic(ctx, *info.Topic, sender, ts) || changed } if info.Avatar != nil { portal.NameIsCustom = true - changed = portal.updateAvatar(ctx, info.Avatar, sender, ts, info.ExcludeChangesFromTimeline) || changed + changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed } if info.Disappear != nil { - changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{ - Sender: sender, - Timestamp: ts, - Implicit: false, - Save: false, - - SendNotice: !info.ExcludeChangesFromTimeline, - ExcludeFromTimeline: info.ExcludeChangesFromTimeline, - }) || changed + changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed } if info.ParentID != nil { changed = portal.updateParent(ctx, *info.ParentID, source) || changed } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? - portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil) + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) } if info.Type != nil && portal.RoomType != *info.Type { if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { @@ -5003,10 +3194,6 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.RoomType = *info.Type } } - if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest { - changed = true - portal.MessageRequest = *info.MessageRequest - } if info.Members != nil && portal.MXID != "" && source != nil { err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { @@ -5020,9 +3207,8 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if source != nil { source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source, false) - changed = portal.UpdateCapabilities(ctx, source, false) || changed } - if info.CanBackfill && source != nil && portal.MXID != "" { + if info.CanBackfill && source != nil { err := portal.Bridge.DB.BackfillTask.EnsureExists(ctx, portal.PortalKey, source.ID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure backfill queue task exists") @@ -5048,12 +3234,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } return nil } - if portal.deleted.IsSet() { - return ErrPortalIsDeleted - } waiter := make(chan struct{}) closed := false - evt := &portalCreateEvent{ + portal.events <- &portalCreateEvent{ ctx: ctx, source: source, info: info, @@ -5065,15 +3248,6 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } }, } - if PortalEventBuffer == 0 { - go portal.queueEvent(ctx, evt) - } else { - select { - case portal.events <- evt: - case <-portal.deleted.GetChan(): - return ErrPortalIsDeleted - } - } select { case <-ctx.Done(): return ctx.Err() @@ -5083,11 +3257,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { - cancellableCtx, cancel := context.WithCancel(ctx) - defer cancel() - portal.cancelRoomCreate.CompareAndSwap(nil, &cancel) portal.roomCreateLock.Lock() - portal.cancelRoomCreate.Store(&cancel) defer portal.roomCreateLock.Unlock() if portal.MXID != "" { if source != nil { @@ -5098,7 +3268,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo log := zerolog.Ctx(ctx).With(). Str("action", "create matrix room"). Logger() - cancellableCtx = log.WithContext(cancellableCtx) ctx = log.WithContext(ctx) log.Info().Msg("Creating Matrix room") @@ -5107,17 +3276,14 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo if info != nil { log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info") } - info, err = source.Client.GetChatInfo(cancellableCtx, portal) + info, err = source.Client.GetChatInfo(ctx, portal) if err != nil { log.Err(err).Msg("Failed to update portal info for creation") return err } } - portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{}) - if cancellableCtx.Err() != nil { - return cancellableCtx.Err() - } + portal.UpdateInfo(ctx, info, source, nil, time.Time{}) powerLevels := &event.PowerLevelsEventContent{ Events: map[string]int{ @@ -5129,7 +3295,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.Bridge.Bot.GetMXID(): 9001, }, } - initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels) + initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err @@ -5138,12 +3304,14 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req := mautrix.ReqCreateRoom{ Visibility: "private", + Name: portal.Name, + Topic: portal.Topic, CreationContent: make(map[string]any), InitialState: make([]*event.Event, 0, 6), Preset: "private_chat", IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, - BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), + BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { @@ -5156,11 +3324,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req.CreationContent["type"] = event.RoomTypeSpace } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() - roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal) - portal.CapState = database.CapabilityState{ - Source: source.ID, - ID: roomFeatures.GetID(), - } req.InitialState = append(req.InitialState, &event.Event{ Type: event.StateElementFunctionalMembers, @@ -5175,51 +3338,19 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo StateKey: &bridgeInfoStateKey, Type: event.StateBridge, Content: event.Content{Parsed: &bridgeInfo}, - }, &event.Event{ - StateKey: &bridgeInfoStateKey, - Type: event.StateBeeperRoomFeatures, - Content: event.Content{Parsed: roomFeatures}, - }, &event.Event{ - Type: event.StateTopic, - Content: event.Content{ - Parsed: &event.TopicEventContent{Topic: portal.Topic}, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, }) - if roomFeatures.DisappearingTimer != nil { + if req.Topic == "" { + // Add explicit topic event if topic is empty to ensure the event is set. + // This ensures that there won't be an extra event later if PUT /state/... is called. req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateBeeperDisappearingTimer, - Content: event.Content{ - Parsed: portal.Disappear.ToEventContent(), - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, - }) - portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet - } - if portal.Name != "" { - req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateRoomName, - Content: event.Content{ - Parsed: &event.RoomNameEventContent{Name: portal.Name}, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, + Type: event.StateTopic, + Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}}, }) } if portal.AvatarMXC != "" { req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateRoomAvatar, - Content: event.Content{ - Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, + Type: event.StateRoomAvatar, + Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}}, }) } if portal.Parent != nil && portal.Parent.MXID != "" { @@ -5238,9 +3369,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Content: event.Content{Parsed: info.JoinRule}, }) } - if cancellableCtx.Err() != nil { - return cancellableCtx.Err() - } roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) if err != nil { log.Err(err).Msg("Failed to create Matrix room") @@ -5251,7 +3379,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.TopicSet = true portal.NameSet = true portal.MXID = roomID - portal.RoomCreated.Set() portal.Bridge.cacheLock.Lock() portal.Bridge.portalsByMXID[roomID] = portal portal.Bridge.cacheLock.Unlock() @@ -5261,7 +3388,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo log.Err(err).Msg("Failed to save portal to database after creating Matrix room") return err } - if info.CanBackfill && portal.RoomType != database.RoomTypeSpace { + if info.CanBackfill { err = portal.Bridge.DB.BackfillTask.Upsert(ctx, &database.BackfillTask{ PortalKey: portal.PortalKey, UserLoginID: source.ID, @@ -5272,13 +3399,12 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } portal.Bridge.WakeupBackfillQueue() } - withoutCancelCtx := zerolog.Ctx(ctx).WithContext(portal.Bridge.BackgroundCtx) if portal.Parent != nil { if portal.Parent.MXID != "" { portal.addToParentSpaceAndSave(ctx, true) } else { log.Info().Msg("Parent portal doesn't exist, creating in background") - go portal.createParentAndAddToSpace(withoutCancelCtx, source) + go portal.createParentAndAddToSpace(ctx, source) } } portal.updateUserLocalInfo(ctx, info.UserLocal, source, true) @@ -5298,34 +3424,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } - portal.addToUserSpaces(ctx) - if info.CanBackfill && - portal.Bridge.Config.Backfill.Enabled && - portal.RoomType != database.RoomTypeSpace && - !portal.Bridge.Background { - portal.doForwardBackfill(ctx, source, nil, backfillBundle) - } - return nil -} - -func (portal *Portal) addToUserSpaces(ctx context.Context) { - if portal.Parent != nil { - return - } - log := zerolog.Ctx(ctx) - withoutCancelCtx := log.WithContext(portal.Bridge.BackgroundCtx) - if portal.Receiver != "" { - login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) - if login != nil { - up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user portal to add portal to spaces") - } else { - login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) - } - } - } else { + if portal.Parent == nil { userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") @@ -5334,19 +3433,18 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) { login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) if login != nil { login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + go login.tryAddPortalToSpace(ctx, portal, up.CopyWithoutValues()) } } } } + portal.doForwardBackfill(ctx, source, nil, backfillBundle) + return nil } func (portal *Portal) Delete(ctx context.Context) error { - if portal.deleted.IsSet() { - return nil - } portal.removeInPortalCache(ctx) - err := portal.safeDBDelete(ctx) + err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) if err != nil { return err } @@ -5356,21 +3454,11 @@ func (portal *Portal) Delete(ctx context.Context) error { return nil } -func (portal *Portal) safeDBDelete(ctx context.Context) error { - err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey) - if err != nil { - return fmt.Errorf("failed to delete messages in portal: %w", err) - } - // TODO delete child portals? - return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) -} - func (portal *Portal) RemoveMXID(ctx context.Context) error { if portal.MXID == "" { return nil } portal.MXID = "" - portal.RoomCreated.Clear() err := portal.Save(ctx) if err != nil { return err @@ -5403,10 +3491,8 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) { } func (portal *Portal) unlockedDelete(ctx context.Context) error { - if portal.deleted.IsSet() { - return nil - } - err := portal.safeDBDelete(ctx) + // TODO delete child portals? + err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) if err != nil { return err } @@ -5415,18 +3501,10 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error { } func (portal *Portal) unlockedDeleteCache() { - if portal.deleted.IsSet() { - return - } delete(portal.Bridge.portalsByKey, portal.PortalKey) if portal.MXID != "" { delete(portal.Bridge.portalsByMXID, portal.MXID) } - portal.deleted.Set() - if portal.events != nil { - // TODO there's a small risk of this racing with a queueEvent call - close(portal.events) - } } func (portal *Portal) Save(ctx context.Context) error { @@ -5434,9 +3512,6 @@ func (portal *Portal) Save(ctx context.Context) error { } func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { - if portal.Receiver != "" && relay.ID != portal.Receiver { - return fmt.Errorf("can't set non-receiver login as relay") - } portal.Relay = relay if relay == nil { portal.RelayLoginID = "" @@ -5449,17 +3524,3 @@ func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { } return nil } - -func (portal *Portal) PerMessageProfileForSender(ctx context.Context, sender networkid.UserID) (profile event.BeeperPerMessageProfile, err error) { - var ghost *Ghost - ghost, err = portal.Bridge.GetGhostByID(ctx, sender) - if err != nil { - return - } - profile.ID = string(ghost.Intent.GetMXID()) - profile.Displayname = ghost.Name - if ghost.AvatarMXC != "" { - profile.AvatarURL = &ghost.AvatarMXC - } - return -} diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 879f07ae..1bff29c1 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -61,26 +61,16 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, return } else if len(resp.Messages) == 0 { log.Debug().Msg("No messages to backfill") - if resp.CompleteCallback != nil { - resp.CompleteCallback() - } return } - log.Debug(). - Int("message_count", len(resp.Messages)). - Bool("mark_read", resp.MarkRead). - Bool("aggressive_deduplication", resp.AggressiveDeduplication). - Msg("Fetched messages for forward backfill, deduplicating before sending") + // TODO check pending messages // TODO mark backfill queue task as done if last message is nil (-> room was empty) and HasMore is false? - resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, lastMessage) + resp.Messages = cutoffMessages(&log, resp.Messages, true, lastMessage) if len(resp.Messages) == 0 { log.Warn().Msg("No messages left to backfill after cutting off old messages") - if resp.CompleteCallback != nil { - resp.CompleteCallback() - } return } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false, resp.CompleteCallback) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false) } func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin, task *database.BackfillTask) error { @@ -134,19 +124,13 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin } else { log.Warn().Msg("No messages to backfill, but HasMore is true") } - if resp.CompleteCallback != nil { - resp.CompleteCallback() - } return nil } - resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, false, firstMessage) + resp.Messages = cutoffMessages(log, resp.Messages, false, firstMessage) if len(resp.Messages) == 0 { - if resp.CompleteCallback != nil { - resp.CompleteCallback() - } return fmt.Errorf("no messages left to backfill after cutting off too new messages") } - portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false, resp.CompleteCallback) + portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false) if len(resp.Messages) > 0 { task.OldestMessageID = resp.Messages[0].ID } @@ -172,12 +156,9 @@ func (portal *Portal) fetchThreadBackfill(ctx context.Context, source *UserLogin log.Debug().Msg("No messages to backfill") return nil } - resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, anchor) + resp.Messages = cutoffMessages(log, resp.Messages, true, anchor) if len(resp.Messages) == 0 { log.Warn().Msg("No messages left to backfill after cutting off old messages") - if resp.CompleteCallback != nil { - resp.CompleteCallback() - } return nil } return resp @@ -194,17 +175,14 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t if err != nil { log.Err(err).Msg("Failed to get last thread message") return - } else if anchorMessage == nil { - log.Warn().Msg("No messages found in thread?") - return } resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) if resp != nil { - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true, resp.CompleteCallback) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true) } } -func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage { +func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bool, lastMessage *database.Message) []*BackfillMessage { if lastMessage == nil { return messages } @@ -218,7 +196,7 @@ func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMe } } if cutoff != -1 { - zerolog.Ctx(ctx).Debug(). + log.Debug(). Int("cutoff_count", cutoff+1). Int("total_count", len(messages)). Time("last_bridged_ts", lastMessage.Timestamp). @@ -235,7 +213,7 @@ func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMe } } if cutoff != -1 { - zerolog.Ctx(ctx).Debug(). + log.Debug(). Int("cutoff_count", len(messages)-cutoff). Int("total_count", len(messages)). Time("oldest_bridged_ts", lastMessage.Timestamp). @@ -243,47 +221,10 @@ func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMe messages = messages[:cutoff] } } - if aggressiveDedup { - filteredMessages := messages[:0] - for _, msg := range messages { - existingMsg, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, msg.ID) - if err != nil { - zerolog.Ctx(ctx).Err(err).Str("message_id", string(msg.ID)).Msg("Failed to check for existing message") - } else if existingMsg != nil { - zerolog.Ctx(ctx).Err(err). - Str("message_id", string(msg.ID)). - Time("message_ts", msg.Timestamp). - Str("message_sender", string(msg.Sender.Sender)). - Msg("Ignoring duplicate message in backfill") - continue - } - if forward && msg.TxnID != "" { - wasPending, _ := portal.checkPendingMessage(ctx, msg) - if wasPending { - zerolog.Ctx(ctx).Err(err). - Str("transaction_id", string(msg.TxnID)). - Str("message_id", string(msg.ID)). - Time("message_ts", msg.Timestamp). - Msg("Found pending message in backfill") - continue - } - } - filteredMessages = append(filteredMessages, msg) - } - messages = filteredMessages - } return messages } -func (portal *Portal) sendBackfill( - ctx context.Context, - source *UserLogin, - messages []*BackfillMessage, - forceForward, - markRead, - inThread bool, - done func(), -) { +func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending unreadThreshold := time.Duration(portal.Bridge.Config.Backfill.UnreadHoursThreshold) * time.Hour forceMarkRead := unreadThreshold > 0 && time.Since(messages[len(messages)-1].Timestamp) > unreadThreshold @@ -298,9 +239,6 @@ func (portal *Portal) sendBackfill( } else { portal.sendLegacyBackfill(ctx, source, messages, markRead || forceMarkRead) } - if done != nil { - done() - } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") if !canBatchSend && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { for _, msg := range messages { @@ -326,13 +264,8 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if len(msg.Parts) == 0 { return } - intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - if !ok { - return - } - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( - ctx, msg.ID, msg.ConvertedMessage, true, - ) + intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) if threadRoot != nil && out.PrevThreadEvents[*msg.ThreadRoot] != "" { prevThreadEvent.MXID = out.PrevThreadEvents[*msg.ThreadRoot] } @@ -341,27 +274,8 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin var firstPart *database.Message for i, part := range msg.Parts { partIDs = append(partIDs, part.ID) - portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) - part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent() + portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) - dbMessage := &database.Message{ - ID: msg.ID, - PartID: part.ID, - MXID: evtID, - Room: portal.PortalKey, - SenderID: msg.Sender.Sender, - SenderMXID: intent.GetMXID(), - Timestamp: msg.Timestamp, - ThreadRoot: ptr.Val(msg.ThreadRoot), - ReplyTo: ptr.Val(msg.ReplyTo), - Metadata: part.DBMetadata, - IsDoublePuppeted: intent.IsDoublePuppet(), - } - if part.DontBridge { - dbMessage.SetFakeMXID() - out.DBMessages = append(out.DBMessages, dbMessage) - continue - } out.Events = append(out.Events, &event.Event{ Sender: intent.GetMXID(), Type: part.Type, @@ -373,6 +287,18 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin Raw: part.Extra, }, }) + dbMessage := &database.Message{ + ID: msg.ID, + PartID: part.ID, + MXID: evtID, + Room: portal.PortalKey, + SenderID: msg.Sender.Sender, + SenderMXID: intent.GetMXID(), + Timestamp: msg.Timestamp, + ThreadRoot: ptr.Val(msg.ThreadRoot), + ReplyTo: ptr.Val(msg.ReplyTo), + Metadata: part.DBMetadata, + } if firstPart == nil { firstPart = dbMessage } @@ -383,34 +309,26 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin prevThreadEvent.MXID = evtID out.PrevThreadEvents[*msg.ThreadRoot] = evtID } - if msg.Disappear.Type != event.DisappearingTypeNone { - if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { + if msg.Disappear.Type != database.DisappearingTypeNone { + if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) } out.Disappear = append(out.Disappear, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: evtID, - Timestamp: msg.Timestamp, DisappearingSetting: msg.Disappear, }) } } slices.Sort(partIDs) for _, reaction := range msg.Reactions { - if reaction == nil { - continue - } - reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) - if !ok { - continue - } + reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) if reaction.TargetPart == nil { reaction.TargetPart = &partIDs[0] } if reaction.Timestamp.IsZero() { reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) } - //lint:ignore SA4006 it's a todo targetPart, ok := partMap[*reaction.TargetPart] if !ok { // TODO warning log and/or skip reaction? @@ -530,11 +448,8 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { var lastPart id.EventID for _, msg := range messages { - intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - if !ok { - continue - } - dbMessages, _ := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { + intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { return z. Str("message_id", string(msg.ID)). Any("sender_id", msg.Sender). @@ -543,10 +458,7 @@ func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, if len(dbMessages) > 0 { lastPart = dbMessages[len(dbMessages)-1].MXID for _, reaction := range msg.Reactions { - reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) - if !ok { - continue - } + reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) targetPart := dbMessages[0] if reaction.TargetPart != nil { targetPartIdx := slices.IndexFunc(dbMessages, func(dbMsg *database.Message) bool { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 4c7e2447..1ee793a9 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -29,32 +29,20 @@ func (portal *PortalInternals) UpdateLogger() { (*Portal)(portal).updateLogger() } -func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { - return (*Portal)(portal).queueEvent(ctx, evt) +func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) { + (*Portal)(portal).queueEvent(ctx, evt) } func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { - return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt) +func (portal *PortalInternals) HandleCreateEvent(evt *portalCreateEvent) { + (*Portal)(portal).handleCreateEvent(evt) } -func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context { - return (*Portal)(portal).getEventCtxWithLog(rawEvt, idx) -} - -func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(EventHandlingResult)) { - (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) -} - -func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error { - return (*Portal)(portal).unwrapBeeperSendState(ctx, evt) -} - -func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { - (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID) +func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) { + (*Portal)(portal).sendSuccessStatus(ctx, evt) } func (portal *PortalInternals) SendErrorStatus(ctx context.Context, evt *event.Event, err error) { @@ -65,24 +53,20 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { - return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest) +func (portal *PortalInternals) HandleMatrixEvent(sender *User, evt *event.Event) { + (*Portal)(portal).handleMatrixEvent(sender, evt) } -func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixReceipts(ctx, evt) +func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) { + (*Portal)(portal).handleMatrixReceipts(ctx, evt) } func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { (*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt) } -func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) { - (*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) -} - -func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixTyping(ctx, evt) +func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) { + (*Portal)(portal).handleMatrixTyping(ctx, evt) } func (portal *PortalInternals) SendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { @@ -93,83 +77,55 @@ func (portal *PortalInternals) PeriodicTypingUpdater() { (*Portal)(portal).periodicTypingUpdater() } -func (portal *PortalInternals) CheckMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { - return (*Portal)(portal).checkMessageContentCaps(caps, content) +func (portal *PortalInternals) CheckMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { + return (*Portal)(portal).checkMessageContentCaps(ctx, caps, content, evt) } -func (portal *PortalInternals) ParseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { - return (*Portal)(portal).parseInputTransactionID(origSender, evt) +func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { + (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) } -func (portal *PortalInternals) PendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { - (*Portal)(portal).pendingMessageTimeoutLoop(ctx, cfg) -} - -func (portal *PortalInternals) CheckPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { - (*Portal)(portal).checkPendingMessages(ctx, cfg) -} - -func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) EventHandlingResult { - return (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) -} - -func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) +func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { + (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) } func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { return (*Portal)(portal).getTargetUser(ctx, userID) } -func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { - return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest) +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { - return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest) +func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixTombstone(ctx, evt) +func (portal *PortalInternals) HandleRemoteEvent(source *UserLogin, evt RemoteEvent) { + (*Portal)(portal).handleRemoteEvent(source, evt) } -func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) { - (*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser) -} - -func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) -} - -func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { - return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) -} - -func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) { - (*Portal)(portal).ensureFunctionalMember(ctx, ghost) -} - -func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { +func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) } -func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsgID networkid.MessageID, currentMsg *ConvertedMessage, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { - return (*Portal)(portal).getRelationMeta(ctx, currentMsgID, currentMsg, isBatchSend) +func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { + return (*Portal)(portal).getRelationMeta(ctx, currentMsg, replyToPtr, threadRootPtr, isBatchSend) } -func (portal *PortalInternals) ApplyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { - (*Portal)(portal).applyRelationMeta(ctx, content, replyTo, threadRoot, prevThreadEvent) +func (portal *PortalInternals) ApplyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + (*Portal)(portal).applyRelationMeta(content, replyTo, threadRoot, prevThreadEvent) } -func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) ([]*database.Message, EventHandlingResult) { +func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { return (*Portal)(portal).sendConvertedMessage(ctx, id, intent, senderID, converted, ts, streamOrder, logContext) } @@ -177,24 +133,24 @@ func (portal *PortalInternals) CheckPendingMessage(ctx context.Context, evt Remo return (*Portal)(portal).checkPendingMessage(ctx, evt) } -func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { +func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { return (*Portal)(portal).handleRemoteUpsert(ctx, source, evt, existing) } -func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { - return (*Portal)(portal).handleRemoteMessage(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { + (*Portal)(portal).handleRemoteMessage(ctx, source, evt) } func (portal *PortalInternals) SendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { (*Portal)(portal).sendRemoteErrorNotice(ctx, intent, err, ts, evtTypeName) } -func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { - return (*Portal)(portal).handleRemoteEdit(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { + (*Portal)(portal).handleRemoteEdit(ctx, source, evt) } -func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) EventHandlingResult { - return (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) +func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) { + (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) } func (portal *PortalInternals) GetTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { @@ -205,84 +161,76 @@ func (portal *PortalInternals) GetTargetReaction(ctx context.Context, evt Remote return (*Portal)(portal).getTargetReaction(ctx, evt) } -func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { - return (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { + (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { - return (*Portal)(portal).handleRemoteReaction(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { + (*Portal)(portal).handleRemoteReaction(ctx, source, evt) } -func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) EventHandlingResult { - return (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) +func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) { + (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) } func (portal *PortalInternals) GetIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { return (*Portal)(portal).getIntentForMXID(ctx, userID) } -func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { - return (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { + (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { - return (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { + (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) } -func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { - return (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) +func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { + (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) } -func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { - return (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { + (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { - return (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { + (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { - return (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { + (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { - return (*Portal)(portal).handleRemoteTyping(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { + (*Portal)(portal).handleRemoteTyping(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { - return (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { + (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { - return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { + (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) } -func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { - return (*Portal)(portal).findOtherLogins(ctx, source) +func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { + (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { - return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { + (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) } -func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { - return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) +func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateName(ctx, name, sender, ts) } -func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { - return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline) +func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateTopic(ctx, topic, sender, ts) } -func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { - return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline) -} - -func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { - return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline) -} - -func (portal *PortalInternals) GetBridgeInfoStateKey() string { - return (*Portal)(portal).getBridgeInfoStateKey() +func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts) } func (portal *PortalInternals) GetBridgeInfo() (string, event.BridgeEventContent) { @@ -293,12 +241,8 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) } -func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool { - return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra) -} - -func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) { - (*Portal)(portal).revertRoomMeta(ctx, evt) +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content) } func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { @@ -309,10 +253,6 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha return (*Portal)(portal).updateOtherUser(ctx, members) } -func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool { - return (*Portal)(portal).roomIsPublic(ctx) -} - func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) } @@ -321,8 +261,8 @@ func (portal *PortalInternals) UpdateUserLocalInfo(ctx context.Context, info *Us (*Portal)(portal).updateUserLocalInfo(ctx, info, source, didJustCreate) } -func (portal *PortalInternals) UpdateParent(ctx context.Context, newParentID networkid.PortalID, source *UserLogin) bool { - return (*Portal)(portal).updateParent(ctx, newParentID, source) +func (portal *PortalInternals) UpdateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { + return (*Portal)(portal).updateParent(ctx, newParent, source) } func (portal *PortalInternals) LockedUpdateInfoFromGhost(ctx context.Context, ghost *Ghost) { @@ -333,10 +273,6 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle) } -func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) { - (*Portal)(portal).addToUserSpaces(ctx) -} - func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) { (*Portal)(portal).removeInPortalCache(ctx) } @@ -361,12 +297,8 @@ func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *Use (*Portal)(portal).doThreadBackfill(ctx, source, threadID) } -func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage { - return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage) -} - -func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool, done func()) { - (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread, done) +func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { + (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread) } func (portal *PortalInternals) CompileBatchMessage(ctx context.Context, source *UserLogin, msg *BackfillMessage, out *compileBatchOutput, inThread bool) { @@ -400,3 +332,7 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove) } + +func (portal *PortalInternals) SetMXIDToExistingRoom(roomID id.RoomID) bool { + return (*Portal)(portal).setMXIDToExistingRoom(roomID) +} diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index c976d97c..a25fe820 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -32,40 +32,21 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta if source == target { return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same") } - log := zerolog.Ctx(ctx).With(). - Str("action", "re-id portal"). - Stringer("source_portal_key", source). - Stringer("target_portal_key", target). - Logger() - ctx = log.WithContext(ctx) + log := zerolog.Ctx(ctx) + log.Debug().Msg("Re-ID'ing portal") defer func() { log.Debug().Msg("Finished handling portal re-ID") }() - acquireCacheLock := func() { - if !br.cacheLock.TryLock() { - log.Debug().Msg("Waiting for global cache lock") - br.cacheLock.Lock() - log.Debug().Msg("Acquired global cache lock after waiting") - } else { - log.Trace().Msg("Acquired global cache lock without waiting") - } - } - log.Debug().Msg("Re-ID'ing portal") - sourcePortal, err := br.GetExistingPortalByKey(ctx, source) + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { log.Debug().Msg("Source portal not found, re-ID is no-op") return ReIDResultNoOp, nil, nil } - if !sourcePortal.roomCreateLock.TryLock() { - if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { - (*cancelCreate)() - } - log.Debug().Msg("Waiting for source portal room creation lock") - sourcePortal.roomCreateLock.Lock() - log.Debug().Msg("Acquired source portal room creation lock after waiting") - } + sourcePortal.roomCreateLock.Lock() defer sourcePortal.roomCreateLock.Unlock() if sourcePortal.MXID == "" { log.Info().Msg("Source portal doesn't have Matrix room, deleting row") @@ -78,37 +59,22 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) - - acquireCacheLock() targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) if err != nil { - br.cacheLock.Unlock() return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } if targetPortal == nil { log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal") err = sourcePortal.unlockedReID(ctx, target) - br.cacheLock.Unlock() if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err) } return ReIDResultSourceReIDd, sourcePortal, nil } - br.cacheLock.Unlock() - - if !targetPortal.roomCreateLock.TryLock() { - if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { - (*cancelCreate)() - } - log.Debug().Msg("Waiting for target portal room creation lock") - targetPortal.roomCreateLock.Lock() - log.Debug().Msg("Acquired target portal room creation lock after waiting") - } + targetPortal.roomCreateLock.Lock() defer targetPortal.roomCreateLock.Unlock() if targetPortal.MXID == "" { log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") - acquireCacheLock() - defer br.cacheLock.Unlock() err = targetPortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err) @@ -123,9 +89,6 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta return c.Stringer("target_portal_mxid", targetPortal.MXID) }) log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal") - sourcePortal.removeInPortalCache(ctx) - acquireCacheLock() - defer br.cacheLock.Unlock() err = sourcePortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err) @@ -133,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta go func() { _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ Parsed: &event.TombstoneEventContent{ - Body: "This room has been merged", + Body: fmt.Sprintf("This room has been merged"), ReplacementRoom: targetPortal.MXID, }, }, time.Now()) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go deleted file mode 100644 index 72bacaff..00000000 --- a/bridgev2/provisionutil/creategroup.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package provisionutil - -import ( - "context" - - "github.com/rs/zerolog" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type RespCreateGroup struct { - ID networkid.PortalID `json:"id"` - MXID id.RoomID `json:"mxid"` - Portal *bridgev2.Portal `json:"-"` - - FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` -} - -func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { - api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups")) - } - zerolog.Ctx(ctx).Debug(). - Any("create_params", params). - Msg("Creating group chat on remote network") - caps := login.Bridge.Network.GetCapabilities() - typeSpec, validType := caps.Provisioning.GroupCreation[params.Type] - if !validType { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type)) - } - if len(params.Participants) < typeSpec.Participants.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) - } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength)) - } - userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) - for i, participant := range params.Participants { - parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant)) - if ok { - participant = parsedParticipant - params.Participants[i] = participant - } - if !typeSpec.Participants.SkipIdentifierValidation { - if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) - } - } - if api.IsThisUser(ctx, participant) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant)) - } - } - if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) - } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength)) - } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength)) - } - if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required")) - } - if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required")) - } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength)) - } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength)) - } - if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required")) - } else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer")) - } - if params.Username == "" && typeSpec.Username.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required")) - } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength)) - } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength)) - } - if params.Parent == nil && typeSpec.Parent.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required")) - } - resp, err := api.CreateGroup(ctx, params) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create group") - return nil, err - } - if resp.PortalKey.IsEmpty() { - return nil, ErrNoPortalKey - } - zerolog.Ctx(ctx).Debug(). - Object("portal_key", resp.PortalKey). - Msg("Successfully created group on remote network") - if resp.Portal == nil { - resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) - } - } - if resp.Portal.MXID == "" { - err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) - } - } - for key, fp := range resp.FailedParticipants { - if fp.InviteEventType == "" { - fp.InviteEventType = event.EventMessage.Type - } - if fp.UserMXID == "" { - ghost, err := login.Bridge.GetGhostByID(ctx, key) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant") - } else if ghost != nil { - fp.UserMXID = ghost.Intent.GetMXID() - } - } - if fp.DMRoomMXID == "" { - portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant") - } else if portal != nil { - fp.DMRoomMXID = portal.MXID - } - } - } - return &RespCreateGroup{ - ID: resp.Portal.ID, - MXID: resp.Portal.MXID, - Portal: resp.Portal, - - FailedParticipants: resp.FailedParticipants, - }, nil -} diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go deleted file mode 100644 index ce163e67..00000000 --- a/bridgev2/provisionutil/listcontacts.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package provisionutil - -import ( - "context" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2" -) - -type RespGetContactList struct { - Contacts []*RespResolveIdentifier `json:"contacts"` -} - -type RespSearchUsers struct { - Results []*RespResolveIdentifier `json:"results"` -} - -func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) { - api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts")) - } - resp, err := api.GetContactList(ctx) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") - return nil, err - } - return &RespGetContactList{ - Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false), - }, nil -} - -func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) { - api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users")) - } - resp, err := api.SearchUsers(ctx, query) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") - return nil, err - } - return &RespSearchUsers{ - Results: processResolveIdentifiers(ctx, login.Bridge, resp, true), - }, nil -} - -func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) { - apiResp = make([]*RespResolveIdentifier, len(resp)) - for i, contact := range resp { - apiContact := &RespResolveIdentifier{ - ID: contact.UserID, - } - apiResp[i] = apiContact - if contact.UserInfo != nil { - if contact.UserInfo.Name != nil { - apiContact.Name = *contact.UserInfo.Name - } - if contact.UserInfo.Identifiers != nil { - apiContact.Identifiers = contact.UserInfo.Identifiers - } - } - if contact.Ghost != nil { - if syncInfo && contact.UserInfo != nil { - contact.Ghost.UpdateInfo(ctx, contact.UserInfo) - } - if contact.Ghost.Name != "" { - apiContact.Name = contact.Ghost.Name - } - if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { - apiContact.Identifiers = contact.Ghost.Identifiers - } - apiContact.AvatarURL = contact.Ghost.AvatarMXC - apiContact.MXID = contact.Ghost.Intent.GetMXID() - } - if contact.Chat != nil { - if contact.Chat.Portal == nil { - var err error - contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - } - } - if contact.Chat.Portal != nil { - apiContact.DMRoomID = contact.Chat.Portal.MXID - } - } - } - return -} diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go deleted file mode 100644 index cfc388d0..00000000 --- a/bridgev2/provisionutil/resolveidentifier.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package provisionutil - -import ( - "context" - "errors" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/id" -) - -type RespResolveIdentifier struct { - ID networkid.UserID `json:"id"` - Name string `json:"name,omitempty"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Identifiers []string `json:"identifiers,omitempty"` - MXID id.UserID `json:"mxid,omitempty"` - DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` - - Portal *bridgev2.Portal `json:"-"` - Ghost *bridgev2.Ghost `json:"-"` - JustCreated bool `json:"-"` -} - -var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request") - -func ResolveIdentifier( - ctx context.Context, - login *bridgev2.UserLogin, - identifier string, - createChat bool, -) (*RespResolveIdentifier, error) { - api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers")) - } - var resp *bridgev2.ResolveIdentifierResponse - parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier)) - validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) - if ok && (!vOK || validator.ValidateUserID(parsedUserID)) { - ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID") - return nil, err - } - resp = &bridgev2.ResolveIdentifierResponse{ - Ghost: ghost, - UserID: parsedUserID, - } - gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI) - if ok && createChat { - resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat") - return nil, err - } - } else if createChat || ghost.Name == "" { - zerolog.Ctx(ctx).Debug(). - Bool("create_chat", createChat). - Bool("has_name", ghost.Name != ""). - Msg("Falling back to resolving identifier") - resp = nil - identifier = string(parsedUserID) - } - } - if resp == nil { - var err error - resp, err = api.ResolveIdentifier(ctx, identifier, createChat) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier") - return nil, err - } else if resp == nil { - return nil, nil - } - } - apiResp := &RespResolveIdentifier{ - ID: resp.UserID, - Ghost: resp.Ghost, - } - if resp.Ghost != nil { - if resp.UserInfo != nil { - resp.Ghost.UpdateInfo(ctx, resp.UserInfo) - } - apiResp.Name = resp.Ghost.Name - apiResp.AvatarURL = resp.Ghost.AvatarMXC - apiResp.Identifiers = resp.Ghost.Identifiers - apiResp.MXID = resp.Ghost.Intent.GetMXID() - } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { - apiResp.Name = *resp.UserInfo.Name - } - if resp.Chat != nil { - if resp.Chat.PortalKey.IsEmpty() { - return nil, ErrNoPortalKey - } - if resp.Chat.Portal == nil { - var err error - resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) - } - } - resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID) - if createChat && resp.Chat.Portal.MXID == "" { - apiResp.JustCreated = true - err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) - } - } - apiResp.Portal = resp.Chat.Portal - apiResp.DMRoomID = resp.Chat.Portal.MXID - } - return apiResp, nil -} diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 3775c825..a79d56e3 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -63,14 +63,7 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve return true } -var ( - ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()) - ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage() -) - -func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { +func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands log := zerolog.Ctx(ctx) @@ -82,34 +75,37 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH log.Err(err).Msg("Failed to get sender user for incoming Matrix event") status := WrapErrorInStatus(fmt.Errorf("%w: failed to get sender user: %w", ErrDatabaseError, err)) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return EventHandlingResultFailed + return } else if sender == nil { log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") - br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) - return EventHandlingResultFailed + status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return } else if !sender.Permissions.SendEvents { if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { - br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt)) + status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) } - return EventHandlingResultIgnored + return } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { - return EventHandlingResultIgnored + return } } else if evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") - br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) - return EventHandlingResultIgnored + status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return } if evt.Type == event.EventMessage && sender != nil { msg := evt.Content.AsMessage() msg.RemoveReplyFallback() - msg.RemovePerMessageProfileFallback() if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { if !sender.Permissions.Commands { - br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt)) - return EventHandlingResultIgnored + status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return } - go br.Commands.Handle( + br.Commands.Handle( ctx, evt.RoomID, evt.ID, @@ -117,118 +113,41 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), msg.RelatesTo.GetReplyTo(), ) - return EventHandlingResultQueued + return } } if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { - return br.handleBotInvite(ctx, evt, sender) - } else if sender != nil && evt.RoomID == sender.ManagementRoom { - if evt.Type == event.StateMember && evt.Content.AsMember().Membership == event.MembershipLeave && (evt.GetStateKey() == br.Bot.GetMXID().String() || evt.GetStateKey() == sender.MXID.String()) { - sender.ManagementRoom = "" - err := br.DB.User.Update(ctx, sender.User) - if err != nil { - log.Err(err).Msg("Failed to clear user's management room in database") - return EventHandlingResultFailed - } else { - log.Debug().Msg("Cleared user's management room due to leave event") - } - } - return EventHandlingResultSuccess + br.handleBotInvite(ctx, evt, sender) + return } portal, err := br.GetPortalByMXID(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to get portal for incoming Matrix event") status := WrapErrorInStatus(fmt.Errorf("%w: failed to get portal: %w", ErrDatabaseError, err)) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return EventHandlingResultFailed + return } else if portal != nil { - return portal.queueEvent(ctx, &portalMatrixEvent{ + portal.queueEvent(ctx, &portalMatrixEvent{ evt: evt, sender: sender, }) } else if evt.Type == event.StateMember && br.IsGhostMXID(id.UserID(evt.GetStateKey())) && evt.Content.AsMember().Membership == event.MembershipInvite && evt.Content.AsMember().IsDirect { - return br.handleGhostDMInvite(ctx, evt, sender) + br.handleGhostDMInvite(ctx, evt, sender) } else { status := WrapErrorInStatus(ErrNoPortal) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return EventHandlingResultIgnored } } -type EventHandlingResult struct { - Success bool - Ignored bool - Queued bool - - SkipStateEcho bool - - // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. - Error error - // Whether the Error should be sent as a MSS event. - SendMSS bool - - // EventID from the network - EventID id.EventID - // Stream order from the network - StreamOrder int64 -} - -func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult { - ehr.EventID = id - return ehr -} - -func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult { - ehr.StreamOrder = order - return ehr -} - -func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { - if err == nil { - return ehr - } - ehr.Error = err - ehr.Success = false - return ehr -} - -func (ehr EventHandlingResult) WithMSS() EventHandlingResult { - ehr.SendMSS = true - return ehr -} - -func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult { - ehr.SkipStateEcho = skip - return ehr -} - -func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult { - if err == nil { - return ehr - } - return ehr.WithError(err).WithMSS() -} - -var ( - EventHandlingResultFailed = EventHandlingResult{} - EventHandlingResultQueued = EventHandlingResult{Success: true, Queued: true} - EventHandlingResultSuccess = EventHandlingResult{Success: true} - EventHandlingResultIgnored = EventHandlingResult{Success: true, Ignored: true} -) - -func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult { - return ul.Bridge.QueueRemoteEvent(ul, evt) -} - -func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult { +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log - ctx := log.WithContext(br.BackgroundCtx) + ctx := log.WithContext(context.TODO()) maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) isUncertain := ok && maybeUncertain.PortalReceiverIsUncertain() key := evt.GetPortalKey() var portal *Portal var err error - if isUncertain && !br.Config.SplitPortals { + if isUncertain { portal, err = br.GetExistingPortalByKey(ctx, key) } else { portal, err = br.GetPortalByKey(ctx, key) @@ -236,18 +155,18 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandl if err != nil { log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain). Msg("Failed to get portal to handle remote event") - return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err)) + return } else if portal == nil { log.Warn(). Stringer("event_type", evt.GetType()). Object("portal_key", key). Bool("uncertain_receiver", isUncertain). Msg("Portal not found to handle remote event") - return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler) + return } // TODO put this in a better place, and maybe cache to avoid constant db queries login.MarkInPortal(ctx, portal) - return portal.queueEvent(ctx, &portalRemoteEvent{ + portal.queueEvent(ctx, &portalRemoteEvent{ evt: evt, source: login, }) diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go index 56e3a6b1..c725141b 100644 --- a/bridgev2/simplevent/chat.go +++ b/bridgev2/simplevent/chat.go @@ -65,19 +65,14 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) type ChatDelete struct { EventMeta OnlyForMe bool - Children bool } -var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil) +var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) func (evt *ChatDelete) DeleteOnlyForMe() bool { return evt.OnlyForMe } -func (evt *ChatDelete) DeleteChildren() bool { - return evt.Children -} - // ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. type ChatInfoChange struct { EventMeta diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go index f8f8d7e1..928bffc9 100644 --- a/bridgev2/simplevent/message.go +++ b/bridgev2/simplevent/message.go @@ -20,7 +20,6 @@ type Message[T any] struct { Data T ID networkid.MessageID - TransactionID networkid.TransactionID TargetMessage networkid.MessageID ConvertMessageFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, data T) (*bridgev2.ConvertedMessage, error) @@ -29,10 +28,9 @@ type Message[T any] struct { } var ( - _ bridgev2.RemoteMessage = (*Message[any])(nil) - _ bridgev2.RemoteEdit = (*Message[any])(nil) - _ bridgev2.RemoteMessageUpsert = (*Message[any])(nil) - _ bridgev2.RemoteMessageWithTransactionID = (*Message[any])(nil) + _ bridgev2.RemoteMessage = (*Message[any])(nil) + _ bridgev2.RemoteEdit = (*Message[any])(nil) + _ bridgev2.RemoteMessageUpsert = (*Message[any])(nil) ) func (evt *Message[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { @@ -55,61 +53,14 @@ func (evt *Message[T]) GetTargetMessage() networkid.MessageID { return evt.TargetMessage } -func (evt *Message[T]) GetTransactionID() networkid.TransactionID { - return evt.TransactionID -} - -// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data. -type PreConvertedMessage struct { - EventMeta - Data *bridgev2.ConvertedMessage - ID networkid.MessageID - TransactionID networkid.TransactionID - - HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) -} - -var ( - _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil) - _ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil) - _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil) -) - -func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { - return evt.Data, nil -} - -func (evt *PreConvertedMessage) GetID() networkid.MessageID { - return evt.ID -} - -func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID { - return evt.TransactionID -} - -func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) { - if evt.HandleExistingFunc == nil { - return bridgev2.UpsertResult{}, nil - } - return evt.HandleExistingFunc(ctx, portal, intent, existing) -} - type MessageRemove struct { EventMeta TargetMessage networkid.MessageID - OnlyForMe bool } -var ( - _ bridgev2.RemoteMessageRemove = (*MessageRemove)(nil) - _ bridgev2.RemoteDeleteOnlyForMe = (*MessageRemove)(nil) -) +var _ bridgev2.RemoteMessageRemove = (*MessageRemove)(nil) func (evt *MessageRemove) GetTargetMessage() networkid.MessageID { return evt.TargetMessage } - -func (evt *MessageRemove) DeleteOnlyForMe() bool { - return evt.OnlyForMe -} diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 96c8a9c5..a6b278fc 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -7,7 +7,6 @@ package simplevent import ( - "context" "time" "github.com/rs/zerolog" @@ -26,10 +25,6 @@ type EventMeta struct { CreatePortal bool Timestamp time.Time StreamOrder int64 - - PreHandleFunc func(context.Context, *bridgev2.Portal) - PostHandleFunc func(context.Context, *bridgev2.Portal) - MutateContextFunc func(context.Context) context.Context } var ( @@ -38,9 +33,6 @@ var ( _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) _ bridgev2.RemoteEventWithStreamOrder = (*EventMeta)(nil) - _ bridgev2.RemotePreHandler = (*EventMeta)(nil) - _ bridgev2.RemotePostHandler = (*EventMeta)(nil) - _ bridgev2.RemoteEventWithContextMutation = (*EventMeta)(nil) ) func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { @@ -80,74 +72,3 @@ func (evt *EventMeta) GetType() bridgev2.RemoteEventType { func (evt *EventMeta) ShouldCreatePortal() bool { return evt.CreatePortal } - -func (evt *EventMeta) PreHandle(ctx context.Context, portal *bridgev2.Portal) { - if evt.PreHandleFunc != nil { - evt.PreHandleFunc(ctx, portal) - } -} - -func (evt *EventMeta) PostHandle(ctx context.Context, portal *bridgev2.Portal) { - if evt.PostHandleFunc != nil { - evt.PostHandleFunc(ctx, portal) - } -} - -func (evt *EventMeta) MutateContext(ctx context.Context) context.Context { - if evt.MutateContextFunc == nil { - return ctx - } - return evt.MutateContextFunc(ctx) -} - -func (evt EventMeta) WithType(t bridgev2.RemoteEventType) EventMeta { - evt.Type = t - return evt -} - -func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { - evt.LogContext = f - return evt -} - -func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { - origFunc := evt.LogContext - if origFunc == nil { - evt.LogContext = f - return evt - } - evt.LogContext = func(c zerolog.Context) zerolog.Context { - return f(origFunc(c)) - } - return evt -} - -func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta { - evt.PortalKey = p - return evt -} - -func (evt EventMeta) WithUncertainReceiver(u bool) EventMeta { - evt.UncertainReceiver = u - return evt -} - -func (evt EventMeta) WithSender(s bridgev2.EventSender) EventMeta { - evt.Sender = s - return evt -} - -func (evt EventMeta) WithCreatePortal(c bool) EventMeta { - evt.CreatePortal = c - return evt -} - -func (evt EventMeta) WithTimestamp(t time.Time) EventMeta { - evt.Timestamp = t - return evt -} - -func (evt EventMeta) WithStreamOrder(s int64) EventMeta { - evt.StreamOrder = s - return evt -} diff --git a/bridgev2/simplevent/receipt.go b/bridgev2/simplevent/receipt.go index 41614e40..3565986b 100644 --- a/bridgev2/simplevent/receipt.go +++ b/bridgev2/simplevent/receipt.go @@ -19,8 +19,6 @@ type Receipt struct { LastTarget networkid.MessageID Targets []networkid.MessageID ReadUpTo time.Time - - ReadUpToStreamOrder int64 } var ( @@ -40,10 +38,6 @@ func (evt *Receipt) GetReadUpTo() time.Time { return evt.ReadUpTo } -func (evt *Receipt) GetReadUpToStreamOrder() int64 { - return evt.ReadUpToStreamOrder -} - type MarkUnread struct { EventMeta Unread bool diff --git a/bridgev2/space.go b/bridgev2/space.go index 2ca2bce3..17388f3e 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -43,7 +43,7 @@ func (ul *UserLogin) MarkInPortal(ctx context.Context, portal *Portal) { } } if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) { - go ul.tryAddPortalToSpace(context.WithoutCancel(ctx), portal, userPortal.CopyWithoutValues()) + go ul.tryAddPortalToSpace(ctx, portal, userPortal.CopyWithoutValues()) } } } @@ -171,10 +171,6 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { // TODO remove this after initial_members is supported in hungryserv req.BeeperAutoJoinInvites = true } - pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI) - if ok { - pfc.CustomizePersonalFilteringSpace(req) - } ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) if err != nil { return "", fmt.Errorf("failed to create space room: %w", err) diff --git a/bridgev2/status/localbridgestate.go b/bridgev2/status/localbridgestate.go deleted file mode 100644 index 3ad66538..00000000 --- a/bridgev2/status/localbridgestate.go +++ /dev/null @@ -1,23 +0,0 @@ -package status - -type LocalBridgeAccountState string - -const ( - // LocalBridgeAccountStateSetup means the user wants this account to be setup and connected - LocalBridgeAccountStateSetup LocalBridgeAccountState = "SETUP" - // LocalBridgeAccountStateDeleted means the user wants this account to be deleted - LocalBridgeAccountStateDeleted LocalBridgeAccountState = "DELETED" -) - -type LocalBridgeDeviceState string - -const ( - // LocalBridgeDeviceStateSetup means this device is setup to be connected to this account - LocalBridgeDeviceStateSetup LocalBridgeDeviceState = "SETUP" - // LocalBridgeDeviceStateLoggedOut means the user has logged this particular device out while wanting their other devices to remain setup - LocalBridgeDeviceStateLoggedOut LocalBridgeDeviceState = "LOGGED_OUT" - // LocalBridgeDeviceStateError means this particular device has fallen into a persistent error state that may need user intervention to fix - LocalBridgeDeviceStateError LocalBridgeDeviceState = "ERROR" - // LocalBridgeDeviceStateDeleted means this particular device has cleaned up after the account as a whole was requested to be deleted - LocalBridgeDeviceStateDeleted LocalBridgeDeviceState = "DELETED" -) diff --git a/bridgev2/unorganized-docs/FEATURES.md b/bridgev2/unorganized-docs/FEATURES.md index 73da364d..908ca975 100644 --- a/bridgev2/unorganized-docs/FEATURES.md +++ b/bridgev2/unorganized-docs/FEATURES.md @@ -5,45 +5,45 @@ * [x] Attachments * [ ] Polls * [x] Replies - * [x] Threads + * [ ] Threads * [x] Edits * [x] Reactions - * [x] Reaction mass-syncing + * [ ] Reaction mass-syncing * [x] Deletions * [x] Message status events and error notices - * [x] Backfilling history + * [ ] Backfilling history * [x] Login * [x] Logout -* [x] Re-login after credential expiry -* [x] Disappearing messages +* [ ] Re-login after credential expiry +* [ ] Disappearing messages * [x] Read receipts * [ ] Presence -* [x] Typing notifications -* [x] Spaces -* [x] Relay mode -* [x] Chat metadata - * [x] Archive/low priority - * [x] Pin/favorite - * [x] Mark unread - * [x] Mute status - * [x] Temporary mutes ("snooze") +* [ ] Typing notifications +* [ ] Spaces +* [ ] Relay mode +* [ ] Chat metadata + * [ ] Archive/low priority + * [ ] Pin/favorite + * [ ] Mark unread + * [ ] Mute status + * [ ] Temporary mutes ("snooze") * [x] User metadata (name/avatar) -* [x] Group metadata - * [x] Initial meta and full resyncs +* [ ] Group metadata + * [ ] Initial meta and full resyncs * [x] Name, avatar, topic * [x] Members - * [x] Permissions - * [x] Change events - * [x] Name, avatar, topic - * [x] Members (join, leave, invite, kick, ban, knock) - * [x] Permissions (promote, demote) + * [ ] Permissions + * [ ] Change events + * [ ] Name, avatar, topic + * [ ] Members (join, leave, invite, kick, ban, knock) + * [ ] Permissions (promote, demote) * [ ] Misc actions * [ ] Invites / accepting message requests - * [x] Create group - * [x] Create DM - * [x] Get contact list - * [x] Check if identifier is on remote network - * [x] Search users on remote network + * [ ] Create group + * [ ] Create DM + * [ ] Get contact list + * [ ] Check if identifier is on remote network + * [ ] Search users on remote network * [ ] Delete chat * [ ] Report spam * [ ] Custom emojis diff --git a/bridgev2/user.go b/bridgev2/user.go index 9a7896d6..7dc9959a 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -165,26 +165,17 @@ func (user *User) GetUserLoginIDs() []networkid.UserLoginID { return maps.Keys(user.logins) } -// Deprecated: renamed to GetUserLogins func (user *User) GetCachedUserLogins() []*UserLogin { - return user.GetUserLogins() -} - -func (user *User) GetUserLogins() []*UserLogin { user.Bridge.cacheLock.Lock() defer user.Bridge.cacheLock.Unlock() return maps.Values(user.logins) } -func (user *User) HasTooManyLogins() bool { - return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins -} - func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) for key, val := range user.logins { - logins = append(logins, fmt.Sprintf("* `%s` (%s) - `%s`", key, val.RemoteName, val.BridgeState.GetPrev().StateEvent)) + logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, val.RemoteName)) } user.Bridge.cacheLock.Unlock() return strings.Join(logins, "\n") @@ -257,19 +248,3 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { func (user *User) Save(ctx context.Context) error { return user.Bridge.DB.User.Update(ctx, user.User) } - -func (br *Bridge) TrackAnalytics(userID id.UserID, event string, props map[string]any) { - analyticSender, ok := br.Matrix.(MatrixConnectorWithAnalytics) - if ok { - analyticSender.TrackAnalytics(userID, event, props) - } -} - -func (user *User) TrackAnalytics(event string, props map[string]any) { - user.Bridge.TrackAnalytics(user.MXID, event, props) -} - -func (ul *UserLogin) TrackAnalytics(event string, props map[string]any) { - // TODO include user login ID? - ul.Bridge.TrackAnalytics(ul.UserMXID, event, props) -} diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index d56dc4cc..017df773 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -10,7 +10,6 @@ import ( "cmp" "context" "fmt" - "maps" "slices" "sync" "time" @@ -18,10 +17,10 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exsync" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" ) @@ -38,7 +37,6 @@ type UserLogin struct { spaceCreateLock sync.Mutex deleteLock sync.Mutex - disconnectOnce sync.Once } func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { @@ -51,8 +49,6 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } - // TODO if loading the user caused the provided userlogin to be loaded, cancel here? - // Currently this will double-load it } userLogin := &UserLogin{ UserLogin: dbUserLogin, @@ -66,9 +62,6 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { userLogin.Log.Err(err).Msg("Failed to load user login") return nil, nil - } else if userLogin.Client == nil { - userLogin.Log.Error().Msg("LoadUserLogin didn't fill Client") - return nil, nil } userLogin.BridgeState = br.NewBridgeStateQueue(userLogin) user.logins[userLogin.ID] = userLogin @@ -103,13 +96,6 @@ func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User) } func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { - if portal.Receiver != "" { - ul := br.GetCachedUserLoginByID(portal.Receiver) - if ul == nil { - return nil, nil - } - return []*UserLogin{ul}, nil - } logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal) if err != nil { return nil, err @@ -143,12 +129,6 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { return br.userLoginsByID[id] } -func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) { - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - return slices.Collect(maps.Values(br.userLoginsByID)) -} - func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -237,23 +217,19 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params } ul.BridgeState = user.Bridge.NewBridgeStateQueue(ul) } - noCancelCtx := ul.Log.WithContext(user.Bridge.BackgroundCtx) - err = params.LoadUserLogin(noCancelCtx, ul) + err = params.LoadUserLogin(ul.Log.WithContext(context.Background()), ul) if err != nil { return nil, err - } else if ul.Client == nil { - ul.Log.Error().Msg("LoadUserLogin didn't fill Client in NewLogin") - return nil, fmt.Errorf("client not filled by LoadUserLogin") } if doInsert { - err = user.Bridge.DB.UserLogin.Insert(noCancelCtx, ul.UserLogin) + err = user.Bridge.DB.UserLogin.Insert(ctx, ul.UserLogin) if err != nil { return nil, err } user.Bridge.userLoginsByID[ul.ID] = ul user.logins[ul.ID] = ul } else { - err = ul.Save(noCancelCtx) + err = ul.Save(ctx) if err != nil { return nil, err } @@ -272,7 +248,6 @@ func (ul *UserLogin) Logout(ctx context.Context) { type DeleteOpts struct { LogoutRemote bool DontCleanupRooms bool - BlockingCleanup bool unlocked bool } @@ -290,8 +265,7 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts if opts.LogoutRemote { ul.Client.LogoutRemote(ctx) } else { - // we probably shouldn't delete the login if disconnect isn't finished - ul.Disconnect() + ul.Disconnect(nil) } var portals []*database.UserPortal var err error @@ -313,18 +287,10 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts if !opts.unlocked { ul.Bridge.cacheLock.Unlock() } - backgroundCtx := zerolog.Ctx(ctx).WithContext(ul.Bridge.BackgroundCtx) - if !opts.BlockingCleanup { - go ul.deleteSpace(backgroundCtx) - } else { - ul.deleteSpace(backgroundCtx) - } + backgroundCtx := context.WithoutCancel(ctx) + go ul.deleteSpace(backgroundCtx) if portals != nil { - if !opts.BlockingCleanup { - go ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) - } else { - ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) - } + go ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) } if state.StateEvent != "" { ul.BridgeState.Send(state) @@ -416,8 +382,6 @@ func (ul *UserLogin) kickUserFromPortal(ctx context.Context, up *database.UserPo portal, action, reason, err := ul.getLogoutAction(ctx, up, badCredentials) if err != nil { return nil, err - } else if portal == nil { - return nil, nil } zerolog.Ctx(ctx).Debug(). Str("login_id", string(ul.ID)). @@ -506,66 +470,36 @@ func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) erro return ul.Bridge.DB.UserPortal.MarkAsPreferred(ctx, ul.UserLogin, portal.PortalKey) } -var _ status.BridgeStateFiller = (*UserLogin)(nil) +var _ status.StandaloneCustomBridgeStateFiller = (*UserLogin)(nil) func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { state.UserID = ul.UserMXID - state.RemoteID = ul.ID + state.RemoteID = string(ul.ID) state.RemoteName = ul.RemoteName - state.RemoteProfile = ul.RemoteProfile - filler, ok := ul.Client.(status.BridgeStateFiller) + state.RemoteProfile = &ul.RemoteProfile + filler, ok := ul.Client.(status.StandaloneCustomBridgeStateFiller) if ok { return filler.FillBridgeState(state) } return state } -func (ul *UserLogin) Disconnect() { - ul.DisconnectWithTimeout(0) -} - -func (ul *UserLogin) DisconnectWithTimeout(timeout time.Duration) { - ul.disconnectOnce.Do(func() { - ul.disconnectInternal(timeout) - }) -} - -func (ul *UserLogin) disconnectInternal(timeout time.Duration) { - ul.BridgeState.StopUnknownErrorReconnect() - disconnected := make(chan struct{}) - go func() { - ul.Client.Disconnect() - close(disconnected) - }() - - var timeoutC <-chan time.Time - if timeout > 0 { - timeoutC = time.After(timeout) +func (ul *UserLogin) Disconnect(done func()) { + if done != nil { + defer done() } - for { + client := ul.Client + if client != nil { + ul.Client = nil + disconnected := make(chan struct{}) + go func() { + client.Disconnect() + close(disconnected) + }() select { case <-disconnected: - return - case <-time.After(2 * time.Second): - ul.Log.Warn().Msg("Client disconnection taking long") - case <-timeoutC: - ul.Log.Error().Msg("Client disconnection timed out") - return + case <-time.After(5 * time.Second): + ul.Log.Warn().Msg("Client disconnection timed out") } } } - -func (ul *UserLogin) recreateClient(ctx context.Context) error { - oldClient := ul.Client - err := ul.Bridge.Network.LoadUserLogin(ctx, ul) - if err != nil { - return err - } - if ul.Client == oldClient { - zerolog.Ctx(ctx).Warn().Msg("LoadUserLogin didn't update client") - } else { - zerolog.Ctx(ctx).Debug().Msg("Recreated user login client") - } - ul.disconnectOnce = sync.Once{} - return nil -} diff --git a/client.go b/client.go index 045d7b8e..750e3c25 100644 --- a/client.go +++ b/client.go @@ -13,18 +13,12 @@ import ( "net/http" "net/url" "os" - "runtime" - "slices" "strconv" - "strings" "sync/atomic" "time" "github.com/rs/zerolog" - "go.mau.fi/util/ptr" - "go.mau.fi/util/random" "go.mau.fi/util/retryafter" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/event" @@ -41,55 +35,33 @@ type CryptoHelper interface { } type VerificationHelper interface { - // Init initializes the helper. This should be called before any other - // methods. Init(context.Context) error - - // StartVerification starts an interactive verification flow with the given - // user via a to-device event. StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) - // StartInRoomVerification starts an interactive verification flow with the - // given user in the given room. StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) - - // AcceptVerification accepts a verification request. AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error - // DismissVerification dismisses a verification request. This will not send - // a cancellation to the other device. This method should only be called - // *before* the request has been accepted and will error otherwise. - DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error - // CancelVerification cancels a verification request. This method should - // only be called *after* the request has been accepted, although it will - // not error if called beforehand. CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error - // HandleScannedQRData handles the data from a QR code scan. HandleScannedQRData(ctx context.Context, data []byte) error - // ConfirmQRCodeScanned confirms that our QR code has been scanned. ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error - // StartSAS starts a SAS verification flow. StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error - // ConfirmSAS indicates that the user has confirmed that the SAS matches - // SAS shown on the other user's device. ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error } // Client represents a Matrix client. type Client struct { - HomeserverURL *url.URL // The base homeserver URL - UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. - DeviceID id.DeviceID // The device ID of the client. - AccessToken string // The access_token for the client. - UserAgent string // The value for the User-Agent header - Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. - Syncer Syncer // The thing which can process /sync responses - Store SyncStore // The thing which can store tokens/ids - StateStore StateStore - Crypto CryptoHelper - Verification VerificationHelper - SpecVersions *RespVersions - ExternalClient *http.Client // The HTTP client used for external (not matrix) media HTTP requests. + HomeserverURL *url.URL // The base homeserver URL + UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. + DeviceID id.DeviceID // The device ID of the client. + AccessToken string // The access_token for the client. + UserAgent string // The value for the User-Agent header + Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. + Syncer Syncer // The thing which can process /sync responses + Store SyncStore // The thing which can store tokens/ids + StateStore StateStore + Crypto CryptoHelper + Verification VerificationHelper + SpecVersions *RespVersions Log zerolog.Logger @@ -99,7 +71,6 @@ type Client struct { UpdateRequestOnRetry func(req *http.Request, cause error) *http.Request SyncPresence event.Presence - SyncTraceLog bool StreamSyncMinAge time.Duration @@ -111,16 +82,11 @@ type Client struct { // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool - ResponseSizeLimit int64 - txnID int32 // Should the ?user_id= query parameter be set in requests? // See https://spec.matrix.org/v1.6/application-service-api/#identity-assertion SetAppServiceUserID bool - // Should the org.matrix.msc3202.device_id query parameter be set in requests? - // See https://github.com/matrix-org/matrix-spec-proposals/pull/3202 - SetAppServiceDeviceID bool syncingID uint32 // Identifies the current Sync. Only one Sync can be active at any given time. } @@ -142,12 +108,6 @@ type IdentityServerInfo struct { // Use ParseUserID to extract the server name from a user ID. // https://spec.matrix.org/v1.2/client-server-api/#server-discovery func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) { - return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName) -} - -const WellKnownMaxSize = 64 * 1024 - -func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", Host: serverName, @@ -159,11 +119,10 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve return nil, err } - if runtime.GOOS != "js" { - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") - } + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") + client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, err @@ -172,15 +131,11 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve if resp.StatusCode == http.StatusNotFound { return nil, nil - } else if resp.ContentLength > WellKnownMaxSize { - return nil, errors.New(".well-known response too large") } - data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize)) + data, err := io.ReadAll(resp.Body) if err != nil { return nil, err - } else if len(data) >= WellKnownMaxSize { - return nil, errors.New(".well-known response too large") } var wellKnown ClientWellKnown @@ -256,7 +211,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { streamResp = true } timeout := 30000 - if isFailing || nextBatch == "" { + if isFailing { timeout = 0 } resSync, err := cli.FullSyncRequest(ctx, ReqSync{ @@ -330,38 +285,21 @@ type contextKey int const ( LogBodyContextKey contextKey = iota LogRequestIDContextKey - MaxAttemptsContextKey - SyncTokenContextKey ) func (cli *Client) RequestStart(req *http.Request) { - if cli != nil && cli.RequestHook != nil { + if cli.RequestHook != nil { cli.RequestHook(req) } } -// WithMaxRetries updates the context to set the maximum number of retries for any HTTP requests made with the context. -// -// 0 means the request will only be attempted once and will not be retried. -// Negative values will remove the override and fallback to the defaults. -func WithMaxRetries(ctx context.Context, maxRetries int) context.Context { - return context.WithValue(ctx, MaxAttemptsContextKey, maxRetries+1) -} - func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) { - if cli == nil { - return - } var evt *zerolog.Event - if errors.Is(err, context.Canceled) { - evt = zerolog.Ctx(req.Context()).Warn() - } else if err != nil { + if err != nil { evt = zerolog.Ctx(req.Context()).Err(err) } else if handlerErr != nil { evt = zerolog.Ctx(req.Context()).Warn(). AnErr("body_parse_err", handlerErr) - } else if cli.SyncTraceLog && strings.HasSuffix(req.URL.Path, "/_matrix/client/v3/sync") { - evt = zerolog.Ctx(req.Context()).Trace() } else { evt = zerolog.Ctx(req.Context()).Debug() } @@ -386,18 +324,9 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } if body := req.Context().Value(LogBodyContextKey); body != nil { - switch typedLogBody := body.(type) { - case json.RawMessage: - evt.RawJSON("req_body", typedLogBody) - case string: - evt.Str("req_body", typedLogBody) - default: - panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body)) - } + evt.Interface("req_body", body) } - if errors.Is(err, context.Canceled) { - evt.Msg("Request canceled") - } else if err != nil { + if err != nil { evt.Msg("Request failed") } else if handlerErr != nil { evt.Msg("Request parsing failed") @@ -410,43 +339,32 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } -type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error) +type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) type FullRequest struct { - Method string - URL string - Headers http.Header - RequestJSON interface{} - RequestBytes []byte - RequestBody io.Reader - RequestLength int64 - ResponseJSON interface{} - MaxAttempts int - BackoffDuration time.Duration - SensitiveContent bool - Handler ClientResponseHandler - DontReadResponse bool - ResponseSizeLimit int64 - Logger *zerolog.Logger - Client *http.Client + Method string + URL string + Headers http.Header + RequestJSON interface{} + RequestBytes []byte + RequestBody io.Reader + RequestLength int64 + ResponseJSON interface{} + MaxAttempts int + BackoffDuration time.Duration + SensitiveContent bool + Handler ClientResponseHandler + DontReadResponse bool + Logger *zerolog.Logger + Client *http.Client } var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) { - reqID := atomic.AddInt32(&requestID, 1) - logger := zerolog.Ctx(ctx) - if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { - logger = params.Logger - } - ctx = logger.With(). - Int32("req_id", reqID). - Logger().WithContext(ctx) - var logBody any - var reqBody io.Reader - var reqLen int64 + reqBody := params.RequestBody if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { @@ -457,38 +375,29 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } if params.SensitiveContent && !logSensitiveContent { logBody = "" - } else if len(jsonStr) > 32768 { - logBody = fmt.Sprintf("", len(jsonStr)) } else { - logBody = json.RawMessage(jsonStr) + logBody = params.RequestJSON } reqBody = bytes.NewReader(jsonStr) - reqLen = int64(len(jsonStr)) } else if params.RequestBytes != nil { logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes)) reqBody = bytes.NewReader(params.RequestBytes) - reqLen = int64(len(params.RequestBytes)) - } else if params.RequestBody != nil { - logBody = "" - reqLen = -1 - if params.RequestLength > 0 { - logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) - reqLen = params.RequestLength - } else if params.RequestLength == 0 { - zerolog.Ctx(ctx).Warn(). - Msg("RequestBody passed without specifying request length") - } - reqBody = params.RequestBody - if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok { - // Prevent HTTP from closing the request body, it might be needed for retries - reqBody = nopCloseSeeker{rsc} - } + params.RequestLength = int64(len(params.RequestBytes)) + } else if params.RequestLength > 0 && params.RequestBody != nil { + logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) } else if params.Method != http.MethodGet && params.Method != http.MethodHead { params.RequestJSON = struct{}{} - logBody = json.RawMessage("{}") + logBody = params.RequestJSON reqBody = bytes.NewReader([]byte("{}")) - reqLen = 2 } + reqID := atomic.AddInt32(&requestID, 1) + logger := zerolog.Ctx(ctx) + if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { + logger = params.Logger + } + ctx = logger.With(). + Int32("req_id", reqID). + Logger().WithContext(ctx) ctx = context.WithValue(ctx, LogBodyContextKey, logBody) ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID)) req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody) @@ -504,7 +413,9 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e if params.RequestJSON != nil { req.Header.Set("Content-Type", "application/json") } - req.ContentLength = reqLen + if params.RequestLength > 0 && params.RequestBody != nil { + req.ContentLength = params.RequestLength + } return req, nil } @@ -514,19 +425,8 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b } func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullRequest) ([]byte, *http.Response, error) { - if cli == nil { - return nil, nil, ErrClientIsNil - } - if cli.HomeserverURL == nil || cli.HomeserverURL.Scheme == "" { - return nil, nil, ErrClientHasNoHomeserver - } if params.MaxAttempts == 0 { - maxAttempts, ok := ctx.Value(MaxAttemptsContextKey).(int) - if ok && maxAttempts > 0 { - params.MaxAttempts = maxAttempts - } else { - params.MaxAttempts = 1 + cli.DefaultHTTPRetries - } + params.MaxAttempts = 1 + cli.DefaultHTTPRetries } if params.BackoffDuration == 0 { if cli.DefaultHTTPBackoff == 0 { @@ -549,31 +449,14 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque params.Handler = handleNormalResponse } } - if cli.UserAgent != "" { - req.Header.Set("User-Agent", cli.UserAgent) - } + req.Header.Set("User-Agent", cli.UserAgent) if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } - if params.ResponseSizeLimit == 0 { - params.ResponseSizeLimit = cli.ResponseSizeLimit - } - if params.ResponseSizeLimit == 0 { - params.ResponseSizeLimit = DefaultResponseSizeLimit - } if params.Client == nil { params.Client = cli.Client } - return cli.executeCompiledRequest( - req, - params.MaxAttempts-1, - params.BackoffDuration, - params.ResponseJSON, - params.Handler, - params.DontReadResponse, - params.ResponseSizeLimit, - params.Client, - ) + return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -584,17 +467,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry( - req *http.Request, - cause error, - retries int, - backoff time.Duration, - responseJSON any, - handler ClientResponseHandler, - dontReadResponse bool, - sizeLimit int64, - client *http.Client, -) ([]byte, *http.Response, error) { +func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { var err error @@ -616,37 +489,17 @@ func (cli *Client) doRetry( } } log.Warn().Err(cause). - Str("method", req.Method). - Str("url", req.URL.String()). Int("retry_in_seconds", int(backoff.Seconds())). Msg("Request failed, retrying") - select { - case <-time.After(backoff): - case <-req.Context().Done(): - if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) { - return nil, nil, req.Context().Err() - } - } + time.Sleep(backoff) if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) } - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client) } -func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) { - if res.ContentLength > limit { - return nil, HTTPError{ - Request: req, - Response: res, - - Message: "not reading response", - WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), - } - } - contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1)) - if err == nil && len(contents) > int(limit) { - err = ErrBodyReadReachedLimit - } +func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) { + contents, err := io.ReadAll(res.Body) if err != nil { return nil, HTTPError{ Request: req, @@ -667,20 +520,17 @@ func closeTemp(log *zerolog.Logger, file *os.File) { } } -func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { +func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { log := zerolog.Ctx(req.Context()) file, err := os.CreateTemp("", "mautrix-response-") if err != nil { log.Warn().Err(err).Msg("Failed to create temporary file for streaming response") - _, err = handleNormalResponse(req, res, responseJSON, limit) + _, err = handleNormalResponse(req, res, responseJSON) return nil, err } defer closeTemp(log, file) - var n int64 - if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil { + if _, err = io.Copy(file, res.Body); err != nil { return nil, fmt.Errorf("failed to copy response to file: %w", err) - } else if n > limit { - return nil, ErrBodyReadReachedLimit } else if _, err = file.Seek(0, 0); err != nil { return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err) } else if err = json.NewDecoder(file).Decode(responseJSON); err != nil { @@ -690,12 +540,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON any, lim } } -func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { +func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { return nil, nil } -func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { - if contents, err := readResponseBody(req, res, limit); err != nil { +func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { + if contents, err := readResponseBody(req, res); err != nil { return nil, err } else if responseJSON == nil { return contents, nil @@ -713,13 +563,8 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON an } } -const ErrorResponseSizeLimit = 512 * 1024 - -var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024 - func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { - defer res.Body.Close() - contents, err := readResponseBody(req, res, ErrorResponseSizeLimit) + contents, err := readResponseBody(req, res) if err != nil { return contents, err } @@ -738,31 +583,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest( - req *http.Request, - retries int, - backoff time.Duration, - responseJSON any, - handler ClientResponseHandler, - dontReadResponse bool, - sizeLimit int64, - client *http.Client, -) ([]byte, *http.Response, error) { +func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) - duration := time.Since(startTime) + duration := time.Now().Sub(startTime) if res != nil && !dontReadResponse { defer res.Body.Close() } if err != nil { - // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry - canRetry := !errors.Is(err, context.Canceled) || - errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) - if retries > 0 && canRetry { - return cli.doRetry( - req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, - ) + if retries > 0 { + return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client) } err = HTTPError{ Request: req, @@ -777,9 +608,7 @@ func (cli *Client) executeCompiledRequest( if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry( - req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, - ) + return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client) } var body []byte @@ -787,7 +616,7 @@ func (cli *Client) executeCompiledRequest( body, err = ParseErrorResponse(req, res) cli.LogRequestDone(req, res, nil, nil, len(body), duration) } else { - body, err = handler(req, res, responseJSON, sizeLimit) + body, err = handler(req, res, responseJSON) cli.LogRequestDone(req, res, nil, err, len(body), duration) } return body, res, err @@ -795,6 +624,7 @@ func (cli *Client) executeCompiledRequest( // Whoami gets the user ID of the current user. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { + urlPath := cli.BuildClientURL("v3", "account", "whoami") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -825,7 +655,6 @@ type ReqSync struct { FullState bool SetPresence event.Presence StreamResponse bool - UseStateAfter bool BeeperStreaming bool Client *http.Client } @@ -846,10 +675,9 @@ func (req *ReqSync) BuildQuery() map[string]string { if req.FullState { query["full_state"] = "true" } - if req.UseStateAfter { - query["use_state_after"] = "true" - } if req.BeeperStreaming { + // TODO remove this + query["streaming"] = "" query["com.beeper.streaming"] = "true" } return query @@ -871,7 +699,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp } start := time.Now() _, err = cli.MakeFullRequest(ctx, fullReq) - duration := time.Since(start) + duration := time.Now().Sub(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { @@ -918,7 +746,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp return } -func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, @@ -942,7 +770,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[an // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") return cli.register(ctx, u, req) } @@ -951,7 +779,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRe // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } @@ -974,8 +802,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*R // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) { - _, uia, err := cli.Register(ctx, req) +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { + res, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -984,7 +812,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*R return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err := cli.Register(ctx, req) + res, _, err = cli.Register(ctx, req) if err != nil { return nil, err } @@ -1034,22 +862,6 @@ func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, e return } -// Create a device for an appservice user using MSC4190. -func (cli *Client) CreateDeviceMSC4190(ctx context.Context, deviceID id.DeviceID, initialDisplayName string) error { - if len(deviceID) == 0 { - deviceID = id.DeviceID(strings.ToUpper(random.String(10))) - } - _, err := cli.MakeRequest(ctx, http.MethodPut, cli.BuildClientURL("v3", "devices", deviceID), &ReqPutDevice{ - DisplayName: initialDisplayName, - }, nil) - if err != nil { - return err - } - cli.DeviceID = deviceID - cli.SetAppServiceDeviceID = true - return nil -} - // Logout the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout // This does not clear the credentials from the client instance. See ClearCredentials() instead. func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { @@ -1083,19 +895,20 @@ func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, er return } -// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3joinroomidoralias +// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3joinroomidoralias // -// The last parameter contains optional extra fields and can be left nil. -func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias string, req *ReqJoinRoom) (resp *RespJoinRoom, err error) { - if req == nil { - req = &ReqJoinRoom{} +// If serverName is specified, this will be added as a query param to instruct the homeserver to join via that server. If content is specified, it will +// be JSON encoded and used as the request body. +func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { + var urlPath string + if serverName != "" { + urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ + "server_name": serverName, + }) + } else { + urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) } - urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "join", roomIDorAlias}, func(q url.Values) { - if len(req.Via) > 0 { - q["via"] = req.Via - } - }) - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) if err != nil { @@ -1105,28 +918,6 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias string, req *ReqJ return } -// KnockRoom requests to join a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias -// -// The last parameter contains optional extra fields and can be left nil. -func (cli *Client) KnockRoom(ctx context.Context, roomIDorAlias string, req *ReqKnockRoom) (resp *RespKnockRoom, err error) { - if req == nil { - req = &ReqKnockRoom{} - } - urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "knock", roomIDorAlias}, func(q url.Values) { - if len(req.Via) > 0 { - q["via"] = req.Via - } - }) - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) - if err == nil && cli.StateStore != nil { - err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipKnock) - if err != nil { - err = fmt.Errorf("failed to update state store: %w", err) - } - } - return -} - // JoinRoomByID joins the client to a room ID. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidjoin // // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. @@ -1148,54 +939,10 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs return } -func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) { - urlPath := cli.BuildClientURL("v3", "user_directory", "search") - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{ - SearchTerm: query, - Limit: limit, - }, &resp) - return -} - -func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { - supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms) - supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms) - if cli.SpecVersions != nil && !supportsUnstable && !supportsStable { - err = fmt.Errorf("server does not support fetching mutual rooms") - return - } - query := map[string]string{ - "user_id": otherUserID.String(), - } - if len(extras) > 0 { - query["from"] = extras[0].From - } - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "mutual_rooms"}, query) - if !supportsStable && supportsUnstable { - urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) - } - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) - return -} - -func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via ...string) (resp *RespRoomSummary, err error) { - urlPath := ClientURLPath{"unstable", "im.nheko.summary", "summary", roomIDOrAlias} - if cli.SpecVersions.ContainsGreaterOrEqual(SpecV115) { - urlPath = ClientURLPath{"v1", "room_summary", roomIDOrAlias} - } - // TODO add version check after one is added to MSC3266 - fullURL := cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { - if len(via) > 0 { - q["via"] = via - } - }) - _, err = cli.MakeRequest(ctx, http.MethodGet, fullURL, nil, &resp) - return -} - // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { - err = cli.GetProfileField(ctx, mxid, "displayname", &resp) + urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1206,47 +953,25 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay // SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) { - return cli.SetProfileField(ctx, "displayname", displayName) -} - -// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname -func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) - if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { - urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) - } - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{ - key: value, - }, nil) - return -} - -// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname -func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) - if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { - urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) - } - _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) - return -} - -// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname -func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", userID, key) - if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { - urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) - } - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into) + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") + s := struct { + DisplayName string `json:"displayname"` + }{displayName} + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil) return } // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { + urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} - err = cli.GetProfileField(ctx, mxid, "avatar_url", &s) + + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s) + if err != nil { + return + } url = s.AvatarURL return } @@ -1313,6 +1038,15 @@ func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, nam return nil } +type ReqSendEvent struct { + Timestamp int64 + TransactionID string + + DontEncrypt bool + + MeowEventID id.EventID +} + // SendMessageEvent sends a message event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { @@ -1335,14 +1069,8 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } - if req.UnstableDelay > 0 { - queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) - } - if req.UnstableStickyDuration > 0 { - queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) - } - if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { + if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) if err != nil { @@ -1364,77 +1092,12 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event return } -// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint. -// contentJSON should be a value that can be encoded as JSON using json.Marshal. -func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { - var req ReqSendEvent - if len(extra) > 0 { - req = extra[0] - } - - var txnID string - if len(req.TransactionID) > 0 { - txnID = req.TransactionID - } else { - txnID = cli.TxnID() - } - - queryParams := map[string]string{} - if req.Timestamp > 0 { - queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) - } - - if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted { - var isEncrypted bool - isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) - if err != nil { - err = fmt.Errorf("failed to check if room is encrypted: %w", err) - return - } - if isEncrypted { - if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { - err = fmt.Errorf("failed to encrypt event: %w", err) - return - } - eventType = event.EventEncrypted - } - } - - urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID} - urlPath := cli.BuildURLWithQuery(urlData, queryParams) - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - return -} - -// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey +// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { - var req ReqSendEvent - if len(extra) > 0 { - req = extra[0] - } - - queryParams := map[string]string{} - if req.MeowEventID != "" { - queryParams["fi.mau.event_id"] = req.MeowEventID.String() - } - if req.TransactionID != "" { - queryParams["fi.mau.transaction_id"] = req.TransactionID - } - if req.UnstableDelay > 0 { - queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) - } - if req.UnstableStickyDuration > 0 { - queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) - } - if req.Timestamp > 0 { - queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) - } - - urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} - urlPath := cli.BuildURLWithQuery(urlData, queryParams) +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { + urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil && req.UnstableDelay == 0 { + if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return @@ -1442,44 +1105,14 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -// -// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead. func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { - resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{ - Timestamp: ts, + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ + "ts": strconv.FormatInt(ts, 10), }) - return -} - -func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) { - query := map[string]string{} - if req.DelayID != "" { - query["delay_id"] = string(req.DelayID) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) + if err == nil && cli.StateStore != nil { + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } - if req.Status != "" { - query["status"] = string(req.Status) - } - if req.NextBatch != "" { - query["next_batch"] = req.NextBatch - } - - urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp) - - // Migration: merge old keys with new ones - if resp != nil { - resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...) - resp.DelayedEvents = nil - resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...) - resp.FinalisedEvents = nil - } - - return -} - -func (cli *Client) UpdateDelayedEvent(ctx context.Context, req *ReqUpdateDelayedEvent) (resp *RespUpdateDelayedEvent, err error) { - urlPath := cli.BuildClientURL("unstable", "org.matrix.msc4140", "delayed_events", req.DelayID) - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } @@ -1534,19 +1167,6 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id return } -func (cli *Client) UnstableRedactUserEvents(ctx context.Context, roomID id.RoomID, userID id.UserID, req *ReqRedactUser) (resp *RespRedactUserEvents, err error) { - if req == nil { - req = &ReqRedactUser{} - } - query := map[string]string{} - if req.Limit > 0 { - query["limit"] = strconv.Itoa(req.Limit) - } - urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4194", "rooms", roomID, "redact", "user", userID}, query) - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) - return -} - // CreateRoom creates a new Matrix room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom // // resp, err := cli.CreateRoom(&mautrix.ReqCreateRoom{ @@ -1564,10 +1184,6 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update creator membership in state store after creating room") } for _, evt := range req.InitialState { - evt.RoomID = resp.RoomID - if evt.StateKey == nil { - evt.StateKey = ptr.Ptr("") - } UpdateStateStore(ctx, cli.StateStore, evt) } inviteMembership := event.MembershipInvite @@ -1582,6 +1198,9 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update membership in state store after creating room") } } + for _, evt := range req.InitialState { + cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) + } } return } @@ -1692,14 +1311,15 @@ func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err return cli.GetPresence(ctx, cli.UserID) } -func (cli *Client) SetPresence(ctx context.Context, presence ReqPresence) (err error) { +func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { + req := ReqPresence{Presence: status} u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") - _, err = cli.MakeRequest(ctx, http.MethodPut, u, presence, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, nil) return } func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { - if cli == nil || cli.StateStore == nil { + if cli.StateStore == nil { return } fakeEvt := &event.Event{ @@ -1731,8 +1351,8 @@ func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.R UpdateStateStore(ctx, cli.StateStore, fakeEvt) } -// StateEvent gets the content of a single state event in a room. -// It will attempt to JSON unmarshal into the given "outContent" struct with the HTTP response body, or return an error. +// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with +// the HTTP response body, or return an error. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) @@ -1743,43 +1363,12 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e return } -// FullStateEvent gets a single state event in a room. Unlike [StateEvent], this gets the entire event -// (including details like the sender and timestamp). -// This requires the server to support the ?format=event query parameter, which is currently missing from the spec. -// See https://github.com/matrix-org/matrix-spec/issues/1047 for more info -func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (evt *event.Event, err error) { - u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ - "format": "event", - }) - _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &evt) - if evt != nil { - evt.Type.Class = event.StateEventType - _ = evt.Content.ParseRaw(evt.Type) - if evt.RoomID == "" { - evt.RoomID = roomID - } - } - if err == nil && cli.StateStore != nil { - UpdateStateStore(ctx, cli.StateStore, evt) - } - return -} - // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. -func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { - if res.ContentLength > limit { - return nil, HTTPError{ - Request: req, - Response: res, - - Message: "not reading response", - WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), - } - } +func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { response := make(RoomStateMap) responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event) *responsePtr = response - dec := json.NewDecoder(io.LimitReader(res.Body, limit)) + dec := json.NewDecoder(res.Body) arrayStart, err := dec.Token() if err != nil { @@ -1813,8 +1402,6 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any return nil, nil } -type RoomStateMap = map[event.Type]map[string]*event.Event - // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { @@ -1824,63 +1411,36 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt ResponseJSON: &stateMap, Handler: parseRoomStateArray, }) - if stateMap != nil { - pls, ok := stateMap[event.StatePowerLevels][""] - if ok { - pls.Content.AsPowerLevels().CreateEvent = stateMap[event.StateCreate][""] - } - } if err == nil && cli.StateStore != nil { - for evtType, evts := range stateMap { - if evtType == event.StateMember { - continue - } + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID) + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching state") + } + for _, evts := range stateMap { for _, evt := range evts { - if evt.RoomID == "" { - evt.RoomID = roomID - } UpdateStateStore(ctx, cli.StateStore, evt) } } - updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, maps.Values(stateMap[event.StateMember])) - if updateErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(updateErr). - Stringer("room_id", roomID). - Msg("Failed to update members in state store after fetching members") - } - } - return -} - -// StateAsArray gets all the state in a room as an array. It does not update the state store. -// Use State to get the events as a map and also update the state store. -func (cli *Client) StateAsArray(ctx context.Context, roomID id.RoomID) (state []*event.Event, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v3", "rooms", roomID, "state"), nil, &state) - if err == nil { - for _, evt := range state { - evt.Type.Class = event.StateEventType - } } return } // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v1", "media", "config"), nil, &resp) - return -} - -func (cli *Client) RequestOpenIDToken(ctx context.Context) (resp *RespOpenIDToken, err error) { - _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildClientURL("v3", "user", cli.UserID, "openid", "request_token"), nil, &resp) + var u string + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + u = cli.BuildClientURL("v1", "media", "config") + } else { + u = cli.BuildURL(MediaURLPath{"v3", "config"}) + } + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) return } // UploadLink uploads an HTTP URL and then returns an MXC URI. func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { - if cli == nil { - return nil, ErrClientIsNil - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) if err != nil { return nil, err @@ -1896,51 +1456,94 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa return cli.Upload(ctx, res.Body, res.Header.Get("Content-Type"), res.ContentLength) } +// Deprecated: unauthenticated media is deprecated as of Matrix v1.11. Use [Download] or [DownloadBytes] instead. +func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { + return cli.BuildURLWithQuery(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}, map[string]string{"allow_redirect": "true"}) +} + +func (cli *Client) doMediaRetry(req *http.Request, cause error, retries int, backoff time.Duration) (*http.Response, error) { + log := zerolog.Ctx(req.Context()) + if req.Body != nil { + var err error + if req.GetBody != nil { + req.Body, err = req.GetBody() + if err != nil { + log.Warn().Err(err).Msg("Failed to get new body to retry request") + return nil, cause + } + } else if bodySeeker, ok := req.Body.(io.ReadSeeker); ok { + _, err = bodySeeker.Seek(0, io.SeekStart) + if err != nil { + log.Warn().Err(err).Msg("Failed to seek to beginning of request body") + return nil, cause + } + } else { + log.Warn().Msg("Failed to get new body to retry request: GetBody is nil and Body is not an io.ReadSeeker") + return nil, cause + } + } + log.Warn().Err(cause). + Int("retry_in_seconds", int(backoff.Seconds())). + Msg("Request failed, retrying") + time.Sleep(backoff) + return cli.doMediaRequest(req, retries-1, backoff*2) +} + +func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.Duration) (*http.Response, error) { + cli.RequestStart(req) + startTime := time.Now() + res, err := cli.Client.Do(req) + duration := time.Now().Sub(startTime) + if err != nil { + if retries > 0 { + return cli.doMediaRetry(req, err, retries, backoff) + } + err = HTTPError{ + Request: req, + Response: res, + + Message: "request error", + WrappedError: err, + } + cli.LogRequestDone(req, res, err, nil, 0, duration) + return nil, err + } + + if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { + backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) + return cli.doMediaRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff) + } + + if res.StatusCode < 200 || res.StatusCode >= 300 { + var body []byte + body, err = ParseErrorResponse(req, res) + cli.LogRequestDone(req, res, err, nil, len(body), duration) + } else { + cli.LogRequestDone(req, res, nil, nil, -1, duration) + } + return res, err +} + func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { - if mxcURL.IsEmpty() { - return nil, fmt.Errorf("empty mxc uri provided to Download") + ctxLog := zerolog.Ctx(ctx) + if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { + ctx = cli.Log.WithContext(ctx) } - _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ - Method: http.MethodGet, - URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), - DontReadResponse: true, - }) - return resp, err -} - -type DownloadThumbnailExtra struct { - Method string - Animated bool -} - -func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) { - if mxcURL.IsEmpty() { - return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail") + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ + Method: http.MethodGet, + URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), + DontReadResponse: true, + }) + return resp, err + } else { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", cli.UserAgent+" (media downloader)") + return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second) } - if len(extras) > 1 { - panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras))) - } - var extra DownloadThumbnailExtra - if len(extras) == 1 { - extra = extras[0] - } - path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID} - query := map[string]string{ - "height": strconv.Itoa(height), - "width": strconv.Itoa(width), - } - if extra.Method != "" { - query["method"] = extra.Method - } - if extra.Animated { - query["animated"] = "true" - } - _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ - Method: http.MethodGet, - URL: cli.BuildURLWithQuery(path, query), - DontReadResponse: true, - }) - return resp, err } func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { @@ -1952,27 +1555,12 @@ func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]b return io.ReadAll(resp.Body) } -type ReqCreateMXC struct { - BeeperUniqueID string - BeeperRoomID id.RoomID -} - // CreateMXC creates a blank Matrix content URI to allow uploading the content asynchronously later. // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create -func (cli *Client) CreateMXC(ctx context.Context, extra ...ReqCreateMXC) (*RespCreateMXC, error) { +func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { var m RespCreateMXC - query := map[string]string{} - if len(extra) > 0 { - if extra[0].BeeperUniqueID != "" { - query["com.beeper.unique_id"] = extra[0].BeeperUniqueID - } - if extra[0].BeeperRoomID != "" { - query["com.beeper.room_id"] = string(extra[0].BeeperRoomID) - } - } - createURL := cli.BuildURLWithQuery(MediaURLPath{"v1", "create"}, query) - _, err := cli.MakeRequest(ctx, http.MethodPost, createURL, nil, &m) + _, err := cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(MediaURLPath{"v1", "create"}), nil, &m) return &m, err } @@ -1984,20 +1572,14 @@ func (cli *Client) CreateMXC(ctx context.Context, extra ...ReqCreateMXC) (*RespC func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCreateMXC, error) { resp, err := cli.CreateMXC(ctx) if err != nil { - req.DoneCallback() return nil, err } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL - if req.AsyncContext == nil { - req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background()) - } go func() { - _, err = cli.UploadMedia(req.AsyncContext, req) + _, err = cli.UploadMedia(ctx, req) if err != nil { - zerolog.Ctx(req.AsyncContext).Err(err). - Stringer("mxc", req.MXC). - Msg("Async upload of media failed") + cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed") } }() return resp, nil @@ -2033,9 +1615,6 @@ type ReqUploadMedia struct { ContentType string FileName string - AsyncContext context.Context - DoneCallback func() - // MXC specifies an existing MXC URI which doesn't have content yet to upload into. // See https://spec.matrix.org/unstable/client-server-api/#put_matrixmediav3uploadservernamemediaid MXC id.ContentURI @@ -2046,25 +1625,16 @@ type ReqUploadMedia struct { } func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) { - cli.Log.Debug(). - Str("url", url). - Int64("content_length", contentLength). - Msg("Uploading media to external URL") + cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err } req.ContentLength = contentLength req.Header.Set("Content-Type", contentType) - if cli.UserAgent != "" { - req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") - } + req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") - if cli.ExternalClient != nil { - return cli.ExternalClient.Do(req) - } else { - return http.DefaultClient.Do(req) - } + return http.DefaultClient.Do(req) } func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { @@ -2073,9 +1643,6 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* if data.ContentBytes != nil { data.ContentLength = int64(len(data.ContentBytes)) reader = bytes.NewReader(data.ContentBytes) - } else if rsc, ok := reader.(io.ReadSeekCloser); ok { - // Prevent HTTP from closing the request body, it might be needed for retries - reader = nopCloseSeeker{rsc} } readerSeeker, canSeek := reader.(io.ReadSeeker) if !canSeek { @@ -2089,25 +1656,14 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* break } err = fmt.Errorf("HTTP %d", resp.StatusCode) - } else if errors.Is(err, context.Canceled) { - cli.Log.Warn().Str("url", data.UnstableUploadURL).Msg("External media upload canceled") - return nil, err } if retries <= 0 { cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). Msg("Error uploading media to external URL, not retrying") return nil, err } - backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries) - cli.Log.Warn().Err(err). - Str("url", data.UnstableUploadURL). - Int("retry_in_seconds", int(backoff.Seconds())). + cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). Msg("Error uploading media to external URL, retrying") - select { - case <-time.After(backoff): - case <-ctx.Done(): - return nil, ctx.Err() - } retries-- _, err = readerSeeker.Seek(0, io.SeekStart) if err != nil { @@ -2131,23 +1687,9 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* return m, nil } -type nopCloseSeeker struct { - io.ReadSeeker -} - -func (nopCloseSeeker) Close() error { - return nil -} - // UploadMedia uploads the given data to the content repository and returns an MXC URI. // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav3upload func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { - if data.DoneCallback != nil { - defer data.DoneCallback() - } - if cli == nil { - return nil, ErrClientIsNil - } if data.UnstableUploadURL != "" { if data.MXC.IsEmpty() { return nil, errors.New("MXC must also be set when uploading to external URL") @@ -2188,7 +1730,13 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM // // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewURL, error) { - reqURL := cli.BuildURLWithQuery(ClientURLPath{"v1", "media", "preview_url"}, map[string]string{ + var urlPath PrefixableURLPath + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + urlPath = ClientURLPath{"v1", "media", "preview_url"} + } else { + urlPath = MediaURLPath{"v3", "preview_url"} + } + reqURL := cli.BuildURLWithQuery(urlPath, map[string]string{ "url": url, }) var output RespPreviewURL @@ -2204,26 +1752,24 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && cli.StateStore != nil { - fakeEvents := make([]*event.Event, len(resp.Joined)) - i := 0 - for userID, member := range resp.Joined { - fakeEvents[i] = &event.Event{ - StateKey: ptr.Ptr(userID.String()), - Type: event.StateMember, - RoomID: roomID, - Content: event.Content{Parsed: &event.MemberEventContent{ - Membership: event.MembershipJoin, - AvatarURL: id.ContentURIString(member.AvatarURL), - Displayname: member.DisplayName, - }}, - } - i++ - } - updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, fakeEvents, event.MembershipJoin) - if updateErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(updateErr). + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). Stringer("room_id", roomID). - Msg("Failed to update members in state store after fetching joined members") + Msg("Failed to clear cached member list after fetching joined members") + } + for userID, member := range resp.Joined { + updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ + Membership: event.MembershipJoin, + AvatarURL: id.ContentURIString(member.AvatarURL), + Displayname: member.DisplayName, + }) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Stringer("user_id", userID). + Msg("Failed to update membership in state store after fetching joined members") + } } } return @@ -2252,20 +1798,20 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb } } if err == nil && cli.StateStore != nil { - var onlyMemberships []event.Membership + var clearMemberships []event.Membership if extra.Membership != "" { - onlyMemberships = []event.Membership{extra.Membership} - } else if extra.NotMembership != "" { - onlyMemberships = []event.Membership{event.MembershipJoin, event.MembershipLeave, event.MembershipInvite, event.MembershipBan, event.MembershipKnock} - onlyMemberships = slices.DeleteFunc(onlyMemberships, func(m event.Membership) bool { - return m == extra.NotMembership - }) + clearMemberships = append(clearMemberships, extra.Membership) } - updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, resp.Chunk, onlyMemberships...) - if updateErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(updateErr). - Stringer("room_id", roomID). - Msg("Failed to update members in state store after fetching members") + if extra.NotMembership == "" { + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching joined members") + } + } + for _, evt := range resp.Chunk { + UpdateStateStore(ctx, cli.StateStore, evt) } } return @@ -2281,12 +1827,6 @@ func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err return } -func (cli *Client) PublicRooms(ctx context.Context, req *ReqPublicRooms) (resp *RespPublicRooms, err error) { - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "publicRooms"}, req.Query()) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) - return -} - // Hierarchy returns a list of rooms that are in the room's hierarchy. See https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // The hierarchy API is provided to walk the space tree and discover the rooms with their aesthetic details. works in a depth-first manner: @@ -2301,10 +1841,11 @@ func (cli *Client) Hierarchy(ctx context.Context, roomID id.RoomID, req *ReqHier // Messages returns a list of message and state events for a room. It uses // pagination query parameters to paginate history in the room. -// See https://spec.matrix.org/v1.12/client-server-api/#get_matrixclientv3roomsroomidmessages +// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to string, dir Direction, filter *FilterPart, limit int) (resp *RespMessages, err error) { query := map[string]string{ - "dir": string(dir), + "from": from, + "dir": string(dir), } if filter != nil { filterJSON, err := json.Marshal(filter) @@ -2313,9 +1854,6 @@ func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to stri } query["filter"] = string(filterJSON) } - if from != "" { - query["from"] = from - } if to != "" { query["to"] = to } @@ -2369,20 +1907,6 @@ func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev return } -func (cli *Client) GetUnredactedEventContent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "event", eventID}, map[string]string{ - "fi.mau.msc2815.include_unredacted_content": "true", - }) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) - return -} - -func (cli *Client) GetRelations(ctx context.Context, roomID id.RoomID, eventID id.EventID, req *ReqGetRelations) (resp *RespGetRelations, err error) { - urlPath := cli.BuildURLWithQuery(append(ClientURLPath{"v1", "rooms", roomID, "relations", eventID}, req.PathSuffix()...), req.Query()) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) - return -} - func (cli *Client) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID) (err error) { return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, nil) } @@ -2687,15 +2211,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req return err } -func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } @@ -2704,7 +2228,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error { +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), @@ -2774,73 +2298,24 @@ func (cli *Client) PutPushRule(ctx context.Context, scope string, kind pushrules return err } -func (cli *Client) ReportEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID, reason string) error { - urlPath := cli.BuildClientURL("v3", "rooms", roomID, "report", eventID) - _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqReport{Reason: reason, Score: -100}, nil) - return err -} - -func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason string) error { - urlPath := cli.BuildClientURL("v3", "rooms", roomID, "report") - _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqReport{Reason: reason, Score: -100}, nil) - return err -} - -// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin. +// BatchSend sends a batch of historical events into a room. This is only available for appservices. // -// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid -func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) { - urlPath := cli.BuildClientURL("v3", "admin", "whois", userID) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) - return -} - -func (cli *Client) makeMSC4323URL(action string, target id.UserID) string { - if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) { - return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target) - } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) { - return cli.BuildClientURL("v1", "admin", action, target) +// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. +func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { + path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} + query := map[string]string{ + "prev_event_id": req.PrevEventID.String(), } - return "" -} - -// GetSuspendedStatus uses MSC4323 to check if a user is suspended. -func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { - urlPath := cli.makeMSC4323URL("suspend", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + if req.BeeperNewMessages { + query["com.beeper.new_messages"] = "true" } - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) - return -} - -// GetLockStatus uses MSC4323 to check if a user is locked. -func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { - urlPath := cli.makeMSC4323URL("lock", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + if req.BeeperMarkReadBy != "" { + query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String() } - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) - return -} - -// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended. -func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { - urlPath := cli.makeMSC4323URL("suspend", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + if len(req.BatchID) > 0 { + query["batch_id"] = req.BatchID.String() } - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res) - return -} - -// SetLockStatus uses MSC4323 to set whether a user account is locked. -func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { - urlPath := cli.makeMSC4323URL("lock", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") - } - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res) + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp) return } @@ -2881,9 +2356,6 @@ func (cli *Client) BeeperDeleteRoom(ctx context.Context, roomID id.RoomID) (err // TxnID returns the next transaction ID. func (cli *Client) TxnID() string { - if cli == nil { - return "client is nil" - } txnID := atomic.AddInt32(&cli.txnID, 1) return fmt.Sprintf("mautrix-go_%d_%d", time.Now().UnixNano(), txnID) } diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go deleted file mode 100644 index c2846427..00000000 --- a/client_ephemeral_test.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mautrix_test - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) { - roomID := id.RoomID("!room:example.com") - evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} - txnID := "txn-123" - - var gotPath string - var gotQueryTS string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - gotQueryTS = r.URL.Query().Get("ts") - assert.Equal(t, http.MethodPut, r.Method) - _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) - })) - defer ts.Close() - - cli, err := mautrix.NewClient(ts.URL, "", "") - require.NoError(t, err) - - _, err = cli.BeeperSendEphemeralEvent( - context.Background(), - roomID, - evtType, - map[string]any{"foo": "bar"}, - mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234}, - ) - require.NoError(t, err) - - assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/")) - assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID)) - assert.Equal(t, "1234", gotQueryTS) -} - -func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`)) - })) - defer ts.Close() - - cli, err := mautrix.NewClient(ts.URL, "", "") - require.NoError(t, err) - - _, err = cli.BeeperSendEphemeralEvent( - context.Background(), - id.RoomID("!room:example.com"), - event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}, - map[string]any{"foo": "bar"}, - ) - require.Error(t, err) - assert.True(t, errors.Is(err, mautrix.MUnrecognized)) -} - -func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) { - roomID := id.RoomID("!room:example.com") - evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} - txnID := "txn-encrypted" - - stateStore := mautrix.NewMemoryStateStore() - err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{ - Algorithm: id.AlgorithmMegolmV1, - }) - require.NoError(t, err) - - fakeCrypto := &fakeCryptoHelper{ - encryptedContent: &event.EncryptedEventContent{ - Algorithm: id.AlgorithmMegolmV1, - MegolmCiphertext: []byte("ciphertext"), - }, - } - - var gotPath string - var gotBody map[string]any - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - assert.Equal(t, http.MethodPut, r.Method) - err := json.NewDecoder(r.Body).Decode(&gotBody) - require.NoError(t, err) - _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) - })) - defer ts.Close() - - cli, err := mautrix.NewClient(ts.URL, "", "") - require.NoError(t, err) - cli.StateStore = stateStore - cli.Crypto = fakeCrypto - - _, err = cli.BeeperSendEphemeralEvent( - context.Background(), - roomID, - evtType, - map[string]any{"foo": "bar"}, - mautrix.ReqSendEvent{TransactionID: txnID}, - ) - require.NoError(t, err) - - assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID)) - assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"]) - assert.Equal(t, 1, fakeCrypto.encryptCalls) - assert.Equal(t, roomID, fakeCrypto.lastRoomID) - assert.Equal(t, evtType, fakeCrypto.lastEventType) -} - -type fakeCryptoHelper struct { - encryptCalls int - lastRoomID id.RoomID - lastEventType event.Type - lastEncryptInput any - encryptedContent *event.EncryptedEventContent -} - -func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) { - f.encryptCalls++ - f.lastRoomID = roomID - f.lastEventType = eventType - f.lastEncryptInput = content - return f.encryptedContent, nil -} - -func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) { - return nil, nil -} - -func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool { - return false -} - -func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) { -} - -func (f *fakeCryptoHelper) Init(context.Context) error { - return nil -} diff --git a/commands/container.go b/commands/container.go deleted file mode 100644 index 9b909b75..00000000 --- a/commands/container.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "fmt" - "slices" - "strings" - "sync" - - "go.mau.fi/util/exmaps" - - "maunium.net/go/mautrix/event/cmdschema" -) - -type CommandContainer[MetaType any] struct { - commands map[string]*Handler[MetaType] - aliases map[string]string - lock sync.RWMutex - parent *Handler[MetaType] -} - -func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { - return &CommandContainer[MetaType]{ - commands: make(map[string]*Handler[MetaType]), - aliases: make(map[string]string), - } -} - -func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent { - data := make(exmaps.Set[*Handler[MetaType]]) - cont.collectHandlers(data) - specs := make([]*cmdschema.EventContent, 0, data.Size()) - for handler := range data.Iter() { - if handler.Parameters != nil { - specs = append(specs, handler.Spec()) - } - } - return specs -} - -func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) { - cont.lock.RLock() - defer cont.lock.RUnlock() - for _, handler := range cont.commands { - into.Add(handler) - if handler.subcommandContainer != nil { - handler.subcommandContainer.collectHandlers(into) - } - } -} - -// Register registers the given command handlers. -func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) { - if cont == nil { - return - } - cont.lock.Lock() - defer cont.lock.Unlock() - for i, handler := range handlers { - if handler == nil { - panic(fmt.Errorf("handler #%d is nil", i+1)) - } - cont.registerOne(handler) - } -} - -func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) { - if strings.ToLower(handler.Name) != handler.Name { - panic(fmt.Errorf("command %q is not lowercase", handler.Name)) - } else if val, alreadyExists := cont.commands[handler.Name]; alreadyExists && val != handler { - panic(fmt.Errorf("tried to register command %q, but it's already registered", handler.Name)) - } else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists { - panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget)) - } - if !slices.Contains(handler.parents, cont.parent) { - handler.parents = append(handler.parents, cont.parent) - handler.nestedNameCache = nil - } - cont.commands[handler.Name] = handler - for _, alias := range handler.Aliases { - if strings.ToLower(alias) != alias { - panic(fmt.Errorf("alias %q is not lowercase", alias)) - } else if val, alreadyExists := cont.aliases[alias]; alreadyExists && val != handler.Name { - panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered for %q", alias, handler.Name, cont.aliases[alias])) - } else if _, alreadyExists = cont.commands[alias]; alreadyExists { - panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered as a command", alias, handler.Name)) - } - cont.aliases[alias] = handler.Name - } - handler.initSubcommandContainer() -} - -func (cont *CommandContainer[MetaType]) Unregister(handlers ...*Handler[MetaType]) { - if cont == nil { - return - } - cont.lock.Lock() - defer cont.lock.Unlock() - for _, handler := range handlers { - cont.unregisterOne(handler) - } -} - -func (cont *CommandContainer[MetaType]) unregisterOne(handler *Handler[MetaType]) { - delete(cont.commands, handler.Name) - for _, alias := range handler.Aliases { - if cont.aliases[alias] == handler.Name { - delete(cont.aliases, alias) - } - } -} - -func (cont *CommandContainer[MetaType]) GetHandler(name string) *Handler[MetaType] { - if cont == nil { - return nil - } - cont.lock.RLock() - defer cont.lock.RUnlock() - alias, ok := cont.aliases[name] - if ok { - name = alias - } - handler, ok := cont.commands[name] - if !ok { - handler = cont.commands[UnknownCommandName] - } - return handler -} diff --git a/commands/event.go b/commands/event.go deleted file mode 100644 index 76d6c9f0..00000000 --- a/commands/event.go +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -// Event contains the data of a single command event. -// It also provides some helper methods for responding to the command. -type Event[MetaType any] struct { - *event.Event - // RawInput is the entire message before splitting into command and arguments. - RawInput string - // ParentCommands is the chain of commands leading up to this command. - // This is only set if the command is a subcommand. - ParentCommands []string - ParentHandlers []*Handler[MetaType] - // Command is the lowercased first word of the message. - Command string - // Args are the rest of the message split by whitespace ([strings.Fields]). - Args []string - // RawArgs is the same as args, but without the splitting by whitespace. - RawArgs string - - StructuredArgs json.RawMessage - - Ctx context.Context - Log *zerolog.Logger - Proc *Processor[MetaType] - Handler *Handler[MetaType] - Meta MetaType - - redactedBy id.EventID -} - -var IDHTMLParser = &format.HTMLParser{ - PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string { - if len(mxid) == 0 { - return displayname - } - if eventID != "" { - return fmt.Sprintf("https://matrix.to/#/%s/%s", mxid, eventID) - } - return mxid - }, - ItalicConverter: func(s string, c format.Context) string { - return fmt.Sprintf("*%s*", s) - }, - Newline: "\n", -} - -// ParseEvent parses a message into a command event struct. -func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] { - content, ok := evt.Content.Parsed.(*event.MessageEventContent) - if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { - return nil - } - text := content.Body - if content.Format == event.FormatHTML { - text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) - } - if content.MSC4391BotCommand != nil { - if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 { - return nil - } - wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand) - wrapped.RawInput = text - return wrapped - } - if len(text) == 0 { - return nil - } - return RawTextToEvent[MetaType](ctx, evt, text) -} - -func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] { - commandParts := strings.Split(content.Command, " ") - return &Event[MetaType]{ - Event: evt, - // Fake a command and args to let the subcommand finder in Process work. - Command: commandParts[0], - Args: commandParts[1:], - Ctx: ctx, - Log: zerolog.Ctx(ctx), - - StructuredArgs: content.Arguments, - } -} - -func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] { - parts := strings.Fields(text) - if len(parts) == 0 { - parts = []string{""} - } - return &Event[MetaType]{ - Event: evt, - RawInput: text, - Command: strings.ToLower(parts[0]), - Args: parts[1:], - RawArgs: strings.TrimLeft(strings.TrimPrefix(text, parts[0]), " "), - Log: zerolog.Ctx(ctx), - Ctx: ctx, - } -} - -type ReplyOpts struct { - AllowHTML bool - AllowMarkdown bool - Reply bool - Thread bool - SendAsText bool - Edit id.EventID - OverrideMentions *event.Mentions - Extra map[string]any -} - -func (evt *Event[MetaType]) Reply(msg string, args ...any) id.EventID { - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - return evt.Respond(msg, ReplyOpts{AllowMarkdown: true, Reply: true}) -} - -func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) id.EventID { - content := format.RenderMarkdown(msg, opts.AllowMarkdown, opts.AllowHTML) - if opts.Thread { - content.SetThread(evt.Event) - } - if opts.Reply { - content.SetReply(evt.Event) - } - if !opts.SendAsText { - content.MsgType = event.MsgNotice - } - if opts.Edit != "" { - content.SetEdit(opts.Edit) - } - if opts.OverrideMentions != nil { - content.Mentions = opts.OverrideMentions - } - var wrapped any = &content - if opts.Extra != nil { - wrapped = &event.Content{ - Parsed: &content, - Raw: opts.Extra, - } - } - resp, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, wrapped) - if err != nil { - zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply") - return "" - } - return resp.EventID -} - -func (evt *Event[MetaType]) React(emoji string) id.EventID { - resp, err := evt.Proc.Client.SendReaction(evt.Ctx, evt.RoomID, evt.ID, emoji) - if err != nil { - zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reaction") - return "" - } - return resp.EventID -} - -func (evt *Event[MetaType]) Redact() id.EventID { - if evt.redactedBy != "" { - return evt.redactedBy - } - resp, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID) - if err != nil { - zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to redact command") - return "" - } - evt.redactedBy = resp.EventID - return resp.EventID -} - -func (evt *Event[MetaType]) MarkRead() { - err := evt.Proc.Client.MarkRead(evt.Ctx, evt.RoomID, evt.ID) - if err != nil { - zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send read receipt") - } -} - -// ShiftArg removes the first argument from the Args list and RawArgs data and returns it. -// RawInput will not be modified. -func (evt *Event[MetaType]) ShiftArg() string { - if len(evt.Args) == 0 { - return "" - } - firstArg := evt.Args[0] - evt.RawArgs = strings.TrimLeft(strings.TrimPrefix(evt.RawArgs, evt.Args[0]), " ") - evt.Args = evt.Args[1:] - return firstArg -} - -// UnshiftArg reverses ShiftArg by adding the given value to the beginning of the Args list and RawArgs data. -func (evt *Event[MetaType]) UnshiftArg(arg string) { - evt.RawArgs = arg + " " + evt.RawArgs - evt.Args = append([]string{arg}, evt.Args...) -} - -func (evt *Event[MetaType]) ParseArgs(into any) error { - return json.Unmarshal(evt.StructuredArgs, into) -} - -func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) { - err = evt.ParseArgs(&into) - return -} - -func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) { - return func(evt *Event[MetaType]) { - parsed, err := ParseArgs[T, MetaType](evt) - if err != nil { - evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct") - // TODO better error, usage info? deduplicate with Process - evt.Reply("Failed to parse arguments: %v", err) - return - } - fn(evt, parsed) - } -} diff --git a/commands/handler.go b/commands/handler.go deleted file mode 100644 index 56f27f06..00000000 --- a/commands/handler.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "strings" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/event/cmdschema" -) - -type Handler[MetaType any] struct { - // Func is the function that is called when the command is executed. - Func func(ce *Event[MetaType]) - - // Name is the primary name of the command. It must be lowercase. - Name string - // Aliases are alternative names for the command. They must be lowercase. - Aliases []string - // Subcommands are subcommands of this command. - Subcommands []*Handler[MetaType] - // PreFunc is a function that is called before checking subcommands. - // It can be used to have parameters between subcommands (e.g. `!rooms `). - // Event.ShiftArg will likely be useful for implementing such parameters. - PreFunc func(ce *Event[MetaType]) - - // Description is a short description of the command. - Description *event.ExtensibleTextContainer - // Parameters is a description of structured command parameters. - // If set, the StructuredArgs field of Event will be populated. - Parameters []*cmdschema.Parameter - TailParam string - - parents []*Handler[MetaType] - nestedNameCache []string - subcommandContainer *CommandContainer[MetaType] -} - -func (h *Handler[MetaType]) NestedNames() []string { - if h.nestedNameCache != nil { - return h.nestedNameCache - } - nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents)) - for _, parent := range h.parents { - if parent == nil { - nestedNames = append(nestedNames, h.Name) - nestedNames = append(nestedNames, h.Aliases...) - } else { - for _, parentName := range parent.NestedNames() { - nestedNames = append(nestedNames, parentName+" "+h.Name) - for _, alias := range h.Aliases { - nestedNames = append(nestedNames, parentName+" "+alias) - } - } - } - } - h.nestedNameCache = nestedNames - return nestedNames -} - -func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { - names := h.NestedNames() - return &cmdschema.EventContent{ - Command: names[0], - Aliases: names[1:], - Parameters: h.Parameters, - Description: h.Description, - TailParam: h.TailParam, - } -} - -func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) { - if h.Parameters == nil { - h.Parameters = other.Parameters - h.TailParam = other.TailParam - } - h.Func = other.Func -} - -func (h *Handler[MetaType]) initSubcommandContainer() { - if len(h.Subcommands) > 0 { - h.subcommandContainer = NewCommandContainer[MetaType]() - h.subcommandContainer.parent = h - h.subcommandContainer.Register(h.Subcommands...) - } else { - h.subcommandContainer = nil - } -} - -func MakeUnknownCommandHandler[MetaType any](prefix string) *Handler[MetaType] { - return &Handler[MetaType]{ - Name: UnknownCommandName, - Func: func(ce *Event[MetaType]) { - if len(ce.ParentCommands) == 0 { - ce.Reply("Unknown command `%s%s`", prefix, ce.Command) - } else { - ce.Reply("Unknown subcommand `%s%s %s`", prefix, strings.Join(ce.ParentCommands, " "), ce.Command) - } - }, - } -} diff --git a/commands/prevalidate.go b/commands/prevalidate.go deleted file mode 100644 index facca4da..00000000 --- a/commands/prevalidate.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "strings" -) - -// A PreValidator contains a function that takes an Event and returns true if the event should be processed further. -// -// The [PreValidator] field in [Processor] is called before the handler of the command is checked. -// It can be used to modify the command or arguments, or to skip the command entirely. -// -// The primary use case is removing a static command prefix, such as requiring all commands start with `!`. -type PreValidator[MetaType any] interface { - Validate(*Event[MetaType]) bool -} - -// FuncPreValidator is a simple function that implements the PreValidator interface. -type FuncPreValidator[MetaType any] func(*Event[MetaType]) bool - -func (f FuncPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { - return f(ce) -} - -// AllPreValidator can be used to combine multiple PreValidators, such that -// all of them must return true for the command to be processed further. -type AllPreValidator[MetaType any] []PreValidator[MetaType] - -func (f AllPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { - for _, validator := range f { - if !validator.Validate(ce) { - return false - } - } - return true -} - -// AnyPreValidator can be used to combine multiple PreValidators, such that -// at least one of them must return true for the command to be processed further. -type AnyPreValidator[MetaType any] []PreValidator[MetaType] - -func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { - for _, validator := range f { - if validator.Validate(ce) { - return true - } - } - return false -} - -// ValidatePrefixCommand checks that the first word in the input is exactly the given string, -// and if so, removes it from the command and sets the command to the next word. -// -// For example, `ValidateCommandPrefix("!mybot")` would only allow commands in the form `!mybot foo`, -// where `foo` would be used to look up the command handler. -func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] { - return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { - if ce.Command == prefix && len(ce.Args) > 0 { - ce.Command = strings.ToLower(ce.ShiftArg()) - return true - } - return false - }) -} - -// ValidatePrefixSubstring checks that the command starts with the given prefix, -// and if so, removes it from the command. -// -// For example, `ValidatePrefixSubstring("!")` would only allow commands in the form `!foo`, -// where `foo` would be used to look up the command handler. -func ValidatePrefixSubstring[MetaType any](prefix string) PreValidator[MetaType] { - return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { - if strings.HasPrefix(ce.Command, prefix) { - ce.Command = ce.Command[len(prefix):] - return true - } - return false - }) -} diff --git a/commands/processor.go b/commands/processor.go deleted file mode 100644 index 80f6745d..00000000 --- a/commands/processor.go +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "runtime/debug" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" -) - -// Processor implements boilerplate code for splitting messages into a command and arguments, -// and finding the appropriate handler for the command. -type Processor[MetaType any] struct { - *CommandContainer[MetaType] - - Client *mautrix.Client - LogArgs bool - PreValidator PreValidator[MetaType] - Meta MetaType - - ReactionCommandPrefix string -} - -// UnknownCommandName is the name of the fallback handler which is used if no other handler is found. -// If even the unknown command handler is not found, the command is ignored. -const UnknownCommandName = "__unknown-command__" - -func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { - proc := &Processor[MetaType]{ - CommandContainer: NewCommandContainer[MetaType](), - Client: cli, - PreValidator: ValidatePrefixSubstring[MetaType]("!"), - } - proc.Register(MakeUnknownCommandHandler[MetaType]("!")) - return proc -} - -func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) { - log := zerolog.Ctx(ctx).With(). - Stringer("sender", evt.Sender). - Stringer("room_id", evt.RoomID). - Stringer("event_id", evt.ID). - Logger() - defer func() { - panicErr := recover() - if panicErr != nil { - logEvt := log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()) - if realErr, ok := panicErr.(error); ok { - logEvt = logEvt.Err(realErr) - } else { - logEvt = logEvt.Any(zerolog.ErrorFieldName, panicErr) - } - logEvt.Msg("Panic in command handler") - _, err := proc.Client.SendReaction(ctx, evt.RoomID, evt.ID, "💥") - if err != nil { - log.Err(err).Msg("Failed to send reaction after panic") - } - } - }() - var parsed *Event[MetaType] - switch evt.Type { - case event.EventReaction: - parsed = proc.ParseReaction(ctx, evt) - case event.EventMessage: - parsed = proc.ParseEvent(ctx, evt) - } - if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) { - return - } - parsed.Proc = proc - parsed.Meta = proc.Meta - parsed.Ctx = ctx - - handler := proc.GetHandler(parsed.Command) - if handler == nil { - return - } - parsed.Handler = handler - if handler.PreFunc != nil { - handler.PreFunc(parsed) - } - handlerChain := zerolog.Arr() - handlerChain.Str(handler.Name) - for handler.subcommandContainer != nil && len(parsed.Args) > 0 { - subHandler := handler.subcommandContainer.GetHandler(strings.ToLower(parsed.Args[0])) - if subHandler != nil { - parsed.ParentCommands = append(parsed.ParentCommands, parsed.Command) - parsed.ParentHandlers = append(parsed.ParentHandlers, handler) - handler = subHandler - handlerChain.Str(subHandler.Name) - parsed.Command = strings.ToLower(parsed.ShiftArg()) - parsed.Handler = subHandler - if subHandler.PreFunc != nil { - subHandler.PreFunc(parsed) - } - } else { - break - } - } - if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { - // TODO allow unknown command handlers to be called? - // The client sent MSC4391 data, but the target command wasn't found - log.Debug().Msg("Didn't find handler for MSC4391 command") - return - } - - logWith := log.With(). - Str("command", parsed.Command). - Array("handler", handlerChain) - if len(parsed.ParentCommands) > 0 { - logWith = logWith.Strs("parent_commands", parsed.ParentCommands) - } - if proc.LogArgs { - logWith = logWith.Strs("args", parsed.Args) - if parsed.StructuredArgs != nil { - logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs) - } - } - log = logWith.Logger() - parsed.Ctx = log.WithContext(ctx) - parsed.Log = &log - - if handler.Parameters != nil && parsed.StructuredArgs == nil { - // The handler wants structured parameters, but the client didn't send MSC4391 data - var err error - parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs) - if err != nil { - log.Debug().Err(err).Msg("Failed to parse structured arguments") - // TODO better error, usage info? deduplicate with WithParsedArgs - parsed.Reply("Failed to parse arguments: %v", err) - return - } - if proc.LogArgs { - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.RawJSON("structured_args", parsed.StructuredArgs) - }) - } - } - - log.Debug().Msg("Processing command") - handler.Func(parsed) -} diff --git a/commands/reactions.go b/commands/reactions.go deleted file mode 100644 index 0d316219..00000000 --- a/commands/reactions.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "encoding/json" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" -) - -const ReactionCommandsKey = "fi.mau.reaction_commands" -const ReactionMultiUseKey = "fi.mau.reaction_multi_use" - -type ReactionCommandData struct { - Command string `json:"command"` - Args any `json:"args,omitempty"` -} - -func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] { - content, ok := evt.Content.Parsed.(*event.ReactionEventContent) - if !ok { - return nil - } - evtID := content.RelatesTo.EventID - if evtID == "" || !strings.HasPrefix(content.RelatesTo.Key, proc.ReactionCommandPrefix) { - return nil - } - targetEvt, err := proc.Client.GetEvent(ctx, evt.RoomID, evtID) - if err != nil { - zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", evtID).Msg("Failed to get target event for reaction") - return nil - } else if targetEvt.Sender != proc.Client.UserID || targetEvt.Unsigned.RedactedBecause != nil { - return nil - } - if targetEvt.Type == event.EventEncrypted { - if proc.Client.Crypto == nil { - zerolog.Ctx(ctx).Warn(). - Stringer("target_event_id", evtID). - Msg("Received reaction to encrypted event, but don't have crypto helper in client") - return nil - } - _ = targetEvt.Content.ParseRaw(targetEvt.Type) - targetEvt, err = proc.Client.Crypto.Decrypt(ctx, targetEvt) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Stringer("target_event_id", evtID). - Msg("Failed to decrypt target event for reaction") - return nil - } - } - reactionCommands, ok := targetEvt.Content.Raw[ReactionCommandsKey].(map[string]any) - if !ok { - zerolog.Ctx(ctx).Trace(). - Stringer("target_event_id", evtID). - Msg("Reaction target event doesn't have commands key") - return nil - } - isMultiUse, _ := targetEvt.Content.Raw[ReactionMultiUseKey].(bool) - rawCmd, ok := reactionCommands[content.RelatesTo.Key] - if !ok { - zerolog.Ctx(ctx).Debug(). - Stringer("target_event_id", evtID). - Str("reaction_key", content.RelatesTo.Key). - Msg("Reaction command not found in target event") - return nil - } - var wrappedEvt *Event[MetaType] - switch typedCmd := rawCmd.(type) { - case string: - wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd) - case map[string]any: - var input event.MSC4391BotCommandInput - if marshaled, err := json.Marshal(typedCmd); err != nil { - - } else if err = json.Unmarshal(marshaled, &input); err != nil { - - } else { - wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input) - } - } - if wrappedEvt == nil { - zerolog.Ctx(ctx).Debug(). - Stringer("target_event_id", evtID). - Str("reaction_key", content.RelatesTo.Key). - Msg("Reaction command data is invalid") - return nil - } - wrappedEvt.Proc = proc - wrappedEvt.Redact() - if !isMultiUse { - DeleteAllReactions(ctx, proc.Client, evt) - } - if wrappedEvt.Command == "" { - return nil - } - return wrappedEvt -} - -func DeleteAllReactionsCommandFunc[MetaType any](ce *Event[MetaType]) { - DeleteAllReactions(ce.Ctx, ce.Proc.Client, ce.Event) -} - -func DeleteAllReactions(ctx context.Context, client *mautrix.Client, evt *event.Event) { - rel, ok := evt.Content.Parsed.(event.Relatable) - if !ok { - return - } - relation := rel.OptionalGetRelatesTo() - if relation == nil { - return - } - targetEvt := relation.GetReplyTo() - if targetEvt == "" { - targetEvt = relation.GetAnnotationID() - } - if targetEvt == "" { - return - } - relations, err := client.GetRelations(ctx, evt.RoomID, targetEvt, &mautrix.ReqGetRelations{ - RelationType: event.RelAnnotation, - EventType: event.EventReaction, - Limit: 20, - }) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get reactions to delete") - return - } - for _, relEvt := range relations.Chunk { - _, err = client.RedactEvent(ctx, relEvt.RoomID, relEvt.ID) - if err != nil { - zerolog.Ctx(ctx).Err(err).Stringer("event_id", relEvt.ID).Msg("Failed to redact reaction event") - } - } -} diff --git a/crypto/account.go b/crypto/account.go index 0bd09ecf..2f012e59 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -7,12 +7,7 @@ package crypto import ( - "encoding/json" - - "github.com/tidwall/sjson" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" @@ -27,61 +22,32 @@ type OlmAccount struct { } func NewOlmAccount() *OlmAccount { - account, err := olm.NewAccount() - if err != nil { - panic(err) - } return &OlmAccount{ - Internal: account, + Internal: *olm.NewAccount(), } } func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) { if len(account.signingKey) == 0 || len(account.identityKey) == 0 { - var err error - account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() - if err != nil { - panic(err) - } + account.signingKey, account.identityKey = account.Internal.IdentityKeys() } return account.signingKey, account.identityKey } func (account *OlmAccount) SigningKey() id.SigningKey { if len(account.signingKey) == 0 { - var err error - account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() - if err != nil { - panic(err) - } + account.signingKey, account.identityKey = account.Internal.IdentityKeys() } return account.signingKey } func (account *OlmAccount) IdentityKey() id.IdentityKey { if len(account.identityKey) == 0 { - var err error - account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() - if err != nil { - panic(err) - } + account.signingKey, account.identityKey = account.Internal.IdentityKeys() } return account.identityKey } -// SignJSON signs the given JSON object following the Matrix specification: -// https://matrix.org/docs/spec/appendices#signing-json -func (account *OlmAccount) SignJSON(obj any) (string, error) { - objJSON, err := json.Marshal(obj) - if err != nil { - return "", err - } - objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") - objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") - signed, err := account.Internal.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) - return string(signed), err -} - func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID) *mautrix.DeviceKeys { deviceKeys := &mautrix.DeviceKeys{ UserID: userID, @@ -93,7 +59,7 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID }, } - signature, err := account.SignJSON(deviceKeys) + signature, err := account.Internal.SignJSON(deviceKeys) if err != nil { panic(err) } @@ -108,13 +74,9 @@ func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID account.Internal.GenOneTimeKeys(uint(newCount)) } oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey) - internalKeys, err := account.Internal.OneTimeKeys() - if err != nil { - panic(err) - } - for keyID, key := range internalKeys { + for keyID, key := range account.Internal.OneTimeKeys() { key := mautrix.OneTimeKey{Key: key} - signature, _ := account.SignJSON(key) + signature, _ := account.Internal.SignJSON(key) key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature) key.IsSigned = true oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go index d6611dc9..bb03f706 100644 --- a/crypto/aescbc/aes_cbc_test.go +++ b/crypto/aescbc/aes_cbc_test.go @@ -7,13 +7,11 @@ package aescbc_test import ( + "bytes" "crypto/aes" "crypto/rand" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "maunium.net/go/mautrix/crypto/aescbc" ) @@ -24,23 +22,32 @@ func TestAESCBC(t *testing.T) { // The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256) key := make([]byte, 32) _, err = rand.Read(key) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } iv := make([]byte, aes.BlockSize) _, err = rand.Read(iv) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } plaintext = []byte("secret message for testing") //increase to next block size for len(plaintext)%8 != 0 { plaintext = append(plaintext, []byte("-")...) } - ciphertext, err = aescbc.Encrypt(key, iv, plaintext) - require.NoError(t, err) + if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil { + t.Fatal(err) + } resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } - assert.Equal(t, string(resultPlainText), string(plaintext)) + if string(resultPlainText) != string(plaintext) { + t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext) + } } func TestAESCBCCase1(t *testing.T) { @@ -54,10 +61,18 @@ func TestAESCBCCase1(t *testing.T) { key := make([]byte, 32) iv := make([]byte, aes.BlockSize) encrypted, err := aescbc.Encrypt(key, iv, input) - require.NoError(t, err) - assert.Equal(t, expected, encrypted, "encrypted output does not match expected") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected, encrypted) { + t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) + } decrypted, err := aescbc.Decrypt(key, iv, encrypted) - require.NoError(t, err) - assert.Equal(t, input, decrypted, "decrypted output does not match input") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(input, decrypted) { + t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input) + } } diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 727aacbf..344db4f0 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -9,7 +9,6 @@ package attachment import ( "crypto/aes" "crypto/cipher" - "crypto/hmac" "crypto/sha256" "encoding/base64" "errors" @@ -21,24 +20,13 @@ import ( ) var ( - ErrHashMismatch = errors.New("mismatching SHA-256 digest") - ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version") - ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") - ErrInvalidKey = errors.New("failed to decode key") - ErrInvalidInitVector = errors.New("failed to decode initialization vector") - ErrInvalidHash = errors.New("failed to decode SHA-256 hash") - ErrReaderClosed = errors.New("encrypting reader was already closed") -) - -// Deprecated: use variables prefixed with Err -var ( - HashMismatch = ErrHashMismatch - UnsupportedVersion = ErrUnsupportedVersion - UnsupportedAlgorithm = ErrUnsupportedAlgorithm - InvalidKey = ErrInvalidKey - InvalidInitVector = ErrInvalidInitVector - InvalidHash = ErrInvalidHash - ReaderClosed = ErrReaderClosed + HashMismatch = errors.New("mismatching SHA-256 digest") + UnsupportedVersion = errors.New("unsupported Matrix file encryption version") + UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") + InvalidKey = errors.New("failed to decode key") + InvalidInitVector = errors.New("failed to decode initialization vector") + InvalidHash = errors.New("failed to decode SHA-256 hash") + ReaderClosed = errors.New("encrypting reader was already closed") ) var ( @@ -96,25 +84,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { - return ErrInvalidKey + return InvalidKey } else if len(ef.InitVector) != ivBase64Length { - return ErrInvalidInitVector + return InvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { - return ErrInvalidHash + return InvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { - return ErrInvalidKey + return InvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { - return ErrInvalidInitVector + return InvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { - return ErrInvalidHash + return InvalidHash } } return nil @@ -139,43 +127,6 @@ func (ef *EncryptedFile) EncryptInPlace(data []byte) { ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(checksum[:]) } -type ReadWriterAt interface { - io.WriterAt - io.Reader -} - -// EncryptFile encrypts the given file in-place and updates the SHA256 hash in the EncryptedFile struct. -func (ef *EncryptedFile) EncryptFile(file ReadWriterAt) error { - err := ef.decodeKeys(false) - if err != nil { - return err - } - block, _ := aes.NewCipher(ef.decoded.key[:]) - stream := cipher.NewCTR(block, ef.decoded.iv[:]) - hasher := sha256.New() - buf := make([]byte, 32*1024) - var writePtr int64 - var n int - for { - n, err = file.Read(buf) - if err != nil && !errors.Is(err, io.EOF) { - return err - } - if n == 0 { - break - } - stream.XORKeyStream(buf[:n], buf[:n]) - _, err = file.WriteAt(buf[:n], writePtr) - if err != nil { - return err - } - writePtr += int64(n) - hasher.Write(buf[:n]) - } - ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(hasher.Sum(nil)) - return nil -} - type encryptingReader struct { stream cipher.Stream hash hash.Hash @@ -190,7 +141,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil) func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { if r.closed { - return 0, ErrReaderClosed + return 0, ReaderClosed } if offset != 0 || whence != io.SeekStart { return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") @@ -211,20 +162,15 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { - return 0, ErrReaderClosed + return 0, ReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return } } n, err = r.source.Read(dst) - if r.isDecrypting { - r.hash.Write(dst[:n]) - } r.stream.XORKeyStream(dst[:n], dst[:n]) - if !r.isDecrypting { - r.hash.Write(dst[:n]) - } + r.hash.Write(dst[:n]) return } @@ -234,8 +180,10 @@ func (r *encryptingReader) Close() (err error) { err = closer.Close() } if r.isDecrypting { - if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { - return ErrHashMismatch + var downloadedChecksum [utils.SHAHashLength]byte + r.hash.Sum(downloadedChecksum[:]) + if downloadedChecksum != r.file.decoded.sha256 { + return HashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) @@ -276,9 +224,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { - return ErrUnsupportedVersion + return UnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { - return ErrUnsupportedAlgorithm + return UnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } @@ -289,13 +237,12 @@ func (ef *EncryptedFile) PrepareForDecryption() error { func (ef *EncryptedFile) DecryptInPlace(data []byte) error { if err := ef.PrepareForDecryption(); err != nil { return err + } else if ef.decoded.sha256 != sha256.Sum256(data) { + return HashMismatch + } else { + utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) + return nil } - dataHash := sha256.Sum256(data) - if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { - return ErrHashMismatch - } - utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) - return nil } // DecryptStream wraps the given io.Reader in order to decrypt the data. @@ -308,10 +255,9 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadSeekCloser { block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ - isDecrypting: true, - stream: cipher.NewCTR(block, ef.decoded.iv[:]), - hash: sha256.New(), - source: reader, - file: ef, + stream: cipher.NewCTR(block, ef.decoded.iv[:]), + hash: sha256.New(), + source: reader, + file: ef, } } diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go index 9fe929ab..d7f1394a 100644 --- a/crypto/attachment/attachments_test.go +++ b/crypto/attachment/attachments_test.go @@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrUnsupportedVersion) + assert.ErrorIs(t, err, UnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) + assert.ErrorIs(t, err, UnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrHashMismatch) + assert.ErrorIs(t, err, HashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrInvalidHash) + assert.ErrorIs(t, err, InvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrInvalidHash) + assert.ErrorIs(t, err, InvalidHash) } diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go index 25250178..ec551dbe 100644 --- a/crypto/backup/encryptedsessiondata.go +++ b/crypto/backup/encryptedsessiondata.go @@ -68,10 +68,6 @@ func calculateCompatMAC(macKey []byte) []byte { // // [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) { - return EncryptSessionDataWithPubkey(backupKey.PublicKey(), sessionData) -} - -func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T) (*EncryptedSessionData[T], error) { sessionJSON, err := json.Marshal(sessionData) if err != nil { return nil, err @@ -82,7 +78,7 @@ func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T) return nil, err } - sharedSecret, err := ephemeralKey.ECDH(pubkey) + sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey()) if err != nil { return nil, err } diff --git a/crypto/canonicaljson/json_test.go b/crypto/canonicaljson/json_test.go index 36476aa4..d1a7f0a5 100644 --- a/crypto/canonicaljson/json_test.go +++ b/crypto/canonicaljson/json_test.go @@ -17,43 +17,31 @@ package canonicaljson import ( "testing" - - "github.com/stretchr/testify/assert" ) -func TestSortJSON(t *testing.T) { - var tests = []struct { - input string - want string - }{ - {"{}", "{}"}, - {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, - {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, - {`[true,false,null]`, `[true,false,null]`}, - {`[9007199254740991]`, `[9007199254740991]`}, - {"\t\n[9007199254740991]", `[9007199254740991]`}, - {`[true,false,null]`, `[true,false,null]`}, - {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, - {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, - {`[true,false,null]`, `[true,false,null]`}, - {`[9007199254740991]`, `[9007199254740991]`}, - {"\t\n[9007199254740991]", `[9007199254740991]`}, - {`[true,false,null]`, `[true,false,null]`}, - } - for _, test := range tests { - t.Run(test.input, func(t *testing.T) { - got := SortJSON([]byte(test.input), nil) +func testSortJSON(t *testing.T, input, want string) { + got := SortJSON([]byte(input), nil) - // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. - assert.EqualValues(t, test.want, string(CompactJSON(got, nil))) - }) + // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. + if string(CompactJSON(got, nil)) != want { + t.Errorf("SortJSON(%q): want %q got %q", input, want, got) } } +func TestSortJSON(t *testing.T) { + testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`) + testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, + `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`) + testSortJSON(t, `[true,false,null]`, `[true,false,null]`) + testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`) + testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`) +} + func testCompactJSON(t *testing.T, input, want string) { - t.Helper() got := string(CompactJSON([]byte(input), nil)) - assert.EqualValues(t, want, got) + if got != want { + t.Errorf("CompactJSON(%q): want %q got %q", input, want, got) + } } func TestCompactJSON(t *testing.T) { @@ -86,23 +74,18 @@ func TestCompactJSON(t *testing.T) { testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`) } -func TestReadHex(t *testing.T) { - tests := []struct { - input string - want uint32 - }{ - - {"0123", 0x0123}, - {"4567", 0x4567}, - {"89AB", 0x89AB}, - {"CDEF", 0xCDEF}, - {"89ab", 0x89AB}, - {"cdef", 0xCDEF}, - } - for _, test := range tests { - t.Run(test.input, func(t *testing.T) { - got := readHexDigits([]byte(test.input)) - assert.Equal(t, test.want, got) - }) +func testReadHex(t *testing.T, input string, want uint32) { + got := readHexDigits([]byte(input)) + if want != got { + t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got) } } + +func TestReadHex(t *testing.T) { + testReadHex(t, "0123", 0x0123) + testReadHex(t, "4567", 0x4567) + testReadHex(t, "89AB", 0x89AB) + testReadHex(t, "CDEF", 0xCDEF) + testReadHex(t, "89ab", 0x89AB) + testReadHex(t, "cdef", 0xCDEF) +} diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 5d9bf5b3..3d01fb99 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -11,8 +11,6 @@ import ( "context" "fmt" - "go.mau.fi/util/jsonbytes" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" @@ -35,9 +33,9 @@ func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache { } type CrossSigningSeeds struct { - MasterKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.master"` - SelfSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.self_signing"` - UserSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.user_signing"` + MasterKey []byte + SelfSigningKey []byte + UserSigningKey []byte } func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds { @@ -103,7 +101,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross masterKeyID: keys.MasterKey.PublicKey(), }, } - masterSig, err := mach.account.SignJSON(masterKey) + masterSig, err := mach.account.Internal.SignJSON(masterKey) if err != nil { return fmt.Errorf("failed to sign master key: %w", err) } @@ -135,7 +133,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross } userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig) - err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 223fc7b5..77efab5b 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -20,20 +20,6 @@ type CrossSigningPublicKeysCache struct { UserSigningKey id.Ed25519 } -func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) { - pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) - if pubkeys != nil { - hasKeys = true - isVerified, err = mach.CryptoStore.IsKeySignedBy( - ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey, - ) - if err != nil { - err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) - } - } - return -} - func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache { if mach.crossSigningPubkeys != nil { return mach.crossSigningPubkeys @@ -63,8 +49,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id if len(dbKeys) > 0 { masterKey, ok := dbKeys[id.XSUsageMaster] if ok { - selfSigning := dbKeys[id.XSUsageSelfSigning] - userSigning := dbKeys[id.XSUsageUserSigning] + selfSigning, _ := dbKeys[id.XSUsageSelfSigning] + userSigning, _ := dbKeys[id.XSUsageUserSigning] return &CrossSigningPublicKeysCache{ MasterKey: masterKey.Key, SelfSigningKey: selfSigning.Key, diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index ae3d1eb1..1d80cc91 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -87,7 +87,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()): masterKey.String(), }, } - signature, err := mach.account.SignJSON(masterKeyObj) + signature, err := mach.account.Internal.SignJSON(masterKeyObj) if err != nil { return fmt.Errorf("failed to sign JSON: %w", err) } diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index fd42880d..389a9fd2 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -8,7 +8,6 @@ package crypto import ( "context" - "errors" "fmt" "maunium.net/go/mautrix" @@ -72,46 +71,6 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex }, passphrase) } -func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error { - keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx) - if err != nil { - return fmt.Errorf("failed to get default SSSS key data: %w", err) - } - key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey) - if errors.Is(err, ssss.ErrUnverifiableKey) { - mach.machOrContextLog(ctx).Warn(). - Str("key_id", keyID). - Msg("SSSS key is unverifiable, trying to use without verifying") - } else if err != nil { - return err - } - err = mach.FetchCrossSigningKeysFromSSSS(ctx, key) - if err != nil { - return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) - } - err = mach.SignOwnDevice(ctx, mach.OwnIdentity()) - if err != nil { - return fmt.Errorf("failed to sign own device: %w", err) - } - err = mach.SignOwnMasterKey(ctx) - if err != nil { - return fmt.Errorf("failed to sign own master key: %w", err) - } - return nil -} - -func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) { - recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "") - if err != nil { - err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err) - } else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil { - err = fmt.Errorf("failed to sign own device: %w", err) - } else if err = mach.SignOwnMasterKey(ctx); err != nil { - err = fmt.Errorf("failed to sign own master key: %w", err) - } - return -} - // GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. // // A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key @@ -138,12 +97,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u // Publish cross-signing keys err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback) if err != nil { - return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err) + return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) } err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { - return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) + return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } return key.RecoveryKey(), keysCache, nil diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 57406b11..968a52a1 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -20,36 +20,38 @@ import ( func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) { log := mach.machOrContextLog(ctx) for userID, userKeys := range crossSigningKeys { - log := log.With().Stringer("user_id", userID).Logger() + log := log.With().Str("user_id", userID.String()).Logger() currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") } - for curKeyUsage, curKey := range currentKeys { - log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() - // got a new key with the same usage as an existing key - for _, newKeyUsage := range userKeys.Usage { - if newKeyUsage == curKeyUsage { - if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { - // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { - log.Error().Err(err).Msg("Error deleting old signatures made by user") - } else { - log.Debug(). - Int64("signature_count", count). - Msg("Dropped signatures made by old key as it has been replaced") + if currentKeys != nil { + for curKeyUsage, curKey := range currentKeys { + log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger() + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { + // old key is not in the new key map, so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { + log.Error().Err(err).Msg("Error deleting old signatures made by user") + } else { + log.Debug(). + Int64("signature_count", count). + Msg("Dropped signatures made by old key as it has been replaced") + } } + break } - break } } } for _, key := range userKeys.Keys { - log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() + log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { - log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key") + log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key") if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { log.Error().Err(err).Msg("Error storing cross-signing key") } @@ -75,16 +77,16 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } } if len(signingKey) != 43 { - log.Trace().Msg("Cross-signing key has a signature from an unknown key") + log.Debug().Msg("Cross-signing key has a signature from an unknown key") continue } - log.Trace().Msg("Verifying cross-signing key signature") + log.Debug().Msg("Verifying cross-signing key signature") if verified, err := signatures.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { log.Warn().Err(err).Msg("Error verifying cross-signing key signature") } else { if verified { - log.Trace().Err(err).Msg("Cross-signing key signature verified") + log.Debug().Err(err).Msg("Cross-signing key signature verified") err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature) if err != nil { log.Error().Err(err).Msg("Error storing cross-signing key signature") diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index b70370a2..e11fb018 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -13,8 +13,6 @@ import ( "testing" "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" @@ -26,12 +24,17 @@ var noopLogger = zerolog.Nop() func getOlmMachine(t *testing.T) *OlmMachine { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - require.NoError(t, err, "Error opening raw database") + if err != nil { + t.Fatalf("Error opening db: %v", err) + } db, err := dbutil.NewWithDB(rawDB, "sqlite3") - require.NoError(t, err, "Error creating database wrapper") + if err != nil { + t.Fatalf("Error opening db: %v", err) + } sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - err = sqlStore.DB.Upgrade(context.TODO()) - require.NoError(t, err, "Error upgrading database") + if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { + t.Fatalf("Error creating tables: %v", err) + } userID := id.UserID("@mautrix") mk, _ := olm.NewPKSigning() @@ -63,25 +66,29 @@ func TestTrustOwnDevice(t *testing.T) { DeviceID: "device", SigningKey: id.Ed25519("deviceKey"), } - assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be") + if m.IsDeviceTrusted(ownDevice) { + t.Error("Own device trusted while it shouldn't be") + } m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2") - trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID) - require.NoError(t, err, "Error checking if own user is trusted") - assert.True(t, trusted, "Own user not trusted while they should be") - assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be") + if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { + t.Error("Own user not trusted while they should be") + } + if !m.IsDeviceTrusted(ownDevice) { + t.Error("Own device not trusted while it should be") + } } func TestTrustOtherUser(t *testing.T) { m := getOlmMachine(t) otherUser := id.UserID("@user") - trusted, err := m.IsUserTrusted(context.TODO(), otherUser) - require.NoError(t, err, "Error checking if other user is trusted") - assert.False(t, trusted, "Other user trusted while they shouldn't be") + if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { + t.Error("Other user trusted while they shouldn't be") + } theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -93,16 +100,16 @@ func TestTrustOtherUser(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig") - trusted, err = m.IsUserTrusted(context.TODO(), otherUser) - require.NoError(t, err, "Error checking if other user is trusted") - assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key") + if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { + t.Error("Other user trusted before their master key has been signed with our user-signing key") + } m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - trusted, err = m.IsUserTrusted(context.TODO(), otherUser) - require.NoError(t, err, "Error checking if other user is trusted") - assert.True(t, trusted, "Other user not trusted while they should be") + if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { + t.Error("Other user not trusted while they should be") + } } func TestTrustOtherDevice(t *testing.T) { @@ -113,11 +120,12 @@ func TestTrustOtherDevice(t *testing.T) { DeviceID: "theirDevice", SigningKey: id.Ed25519("theirDeviceKey"), } - - trusted, err := m.IsUserTrusted(context.TODO(), otherUser) - require.NoError(t, err, "Error checking if other user is trusted") - assert.False(t, trusted, "Other user trusted while they shouldn't be") - assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be") + if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { + t.Error("Other user trusted while they shouldn't be") + } + if m.IsDeviceTrusted(theirDevice) { + t.Error("Other device trusted while it shouldn't be") + } theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -129,17 +137,21 @@ func TestTrustOtherDevice(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - trusted, err = m.IsUserTrusted(context.TODO(), otherUser) - require.NoError(t, err, "Error checking if other user is trusted") - assert.True(t, trusted, "Other user not trusted while they should be") + if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { + t.Error("Other user not trusted while they should be") + } m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(), otherUser, theirMasterKey.PublicKey(), "sig3") - assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK") + if m.IsDeviceTrusted(theirDevice) { + t.Error("Other device trusted before it has been signed with user's SSK") + } m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey(), "sig4") - assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK") + if !m.IsDeviceTrusted(theirDevice) { + t.Error("Other device not trusted while it should be") + } } diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index 4cdf0dd5..04a179df 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -13,9 +13,6 @@ import ( "maunium.net/go/mautrix/id" ) -// ResolveTrust resolves the trust state of the device from cross-signing. -// -// Deprecated: This method doesn't take a context. Use [OlmMachine.ResolveTrustContext] instead. func (mach *OlmMachine) ResolveTrust(device *id.Device) id.TrustState { state, _ := mach.ResolveTrustContext(context.Background(), device) return state @@ -80,12 +77,8 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi } // IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing. -// -// Note: this will return false if resolving the trust state fails due to database errors. -// Use [OlmMachine.ResolveTrustContext] if special error handling is required. -func (mach *OlmMachine) IsDeviceTrusted(ctx context.Context, device *id.Device) bool { - trust, _ := mach.ResolveTrustContext(ctx, device) - switch trust { +func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool { + switch mach.ResolveTrust(device) { case id.TrustStateVerified, id.TrustStateCrossSignedTOFU, id.TrustStateCrossSignedVerified: return true default: diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index b62dc128..7bb7037d 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -15,7 +15,6 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" - _ "go.mau.fi/util/dbutil/litestream" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" @@ -37,12 +36,8 @@ type CryptoHelper struct { DecryptErrorCallback func(*event.Event, error) - MSC4190 bool LoginAs *mautrix.ReqLogin - ASEventProcessor crypto.ASEventProcessor - CustomPostDecrypt func(context.Context, *event.Event) - DBAccountID string } @@ -63,7 +58,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH return nil, fmt.Errorf("pickle key must be provided") } _, isExtensible := cli.Syncer.(mautrix.ExtensibleSyncer) - if !cli.SetAppServiceDeviceID && !isExtensible { + if !isExtensible { return nil, fmt.Errorf("the client syncer must implement ExtensibleSyncer") } @@ -79,7 +74,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH } unmanagedCryptoStore = typedStore case string: - db, err := dbutil.NewWithDialect(fmt.Sprintf("file:%s?_txlock=immediate", typedStore), "sqlite3-fk-wal") + db, err := dbutil.NewWithDialect(typedStore, "sqlite3") if err != nil { return nil, err } @@ -116,9 +111,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { } syncer, ok := helper.client.Syncer.(mautrix.ExtensibleSyncer) if !ok { - if !helper.client.SetAppServiceDeviceID { - return fmt.Errorf("the client syncer must implement ExtensibleSyncer") - } + return fmt.Errorf("the client syncer must implement ExtensibleSyncer") } var stateStore crypto.StateStore @@ -143,42 +136,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to upgrade crypto state store: %w", err) } - cryptoStore = managedCryptoStore - } else { - cryptoStore = helper.unmanagedCryptoStore - } - shouldFindDeviceID := helper.LoginAs != nil || helper.unmanagedCryptoStore == nil - if rawCryptoStore, ok := cryptoStore.(*crypto.SQLCryptoStore); ok && shouldFindDeviceID { - storedDeviceID, err := rawCryptoStore.FindDeviceID(ctx) + storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx) if err != nil { return fmt.Errorf("failed to find existing device ID: %w", err) } - if helper.MSC4190 { - helper.log.Debug().Msg("Creating bot device with MSC4190") - err = helper.client.CreateDeviceMSC4190(ctx, storedDeviceID, helper.LoginAs.InitialDeviceDisplayName) - if err != nil { - return fmt.Errorf("failed to create device for bot: %w", err) - } - rawCryptoStore.DeviceID = helper.client.DeviceID - } else if helper.LoginAs != nil && helper.LoginAs.Type == mautrix.AuthTypeAppservice && helper.client.SetAppServiceDeviceID { - if storedDeviceID == "" { - helper.log.Debug(). - Str("username", helper.LoginAs.Identifier.User). - Msg("Logging in with appservice") - var resp *mautrix.RespLogin - resp, err = helper.client.Login(ctx, helper.LoginAs) - if err != nil { - return err - } - helper.client.DeviceID = resp.DeviceID - } else { - helper.log.Debug(). - Str("username", helper.LoginAs.Identifier.User). - Stringer("device_id", storedDeviceID). - Msg("Using existing device") - helper.client.DeviceID = storedDeviceID - } - } else if helper.LoginAs != nil { + if helper.LoginAs != nil { if storedDeviceID != "" { helper.LoginAs.DeviceID = storedDeviceID } @@ -191,12 +153,18 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return err } + if storedDeviceID == "" { + managedCryptoStore.DeviceID = helper.client.DeviceID + } } else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID { return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID) } - rawCryptoStore.DeviceID = helper.client.DeviceID - } else if helper.LoginAs != nil { - return fmt.Errorf("LoginAs can only be used with a managed crypto store") + cryptoStore = managedCryptoStore + } else { + if helper.LoginAs != nil { + return fmt.Errorf("LoginAs can only be used with a managed crypto store") + } + cryptoStore = helper.unmanagedCryptoStore } if helper.client.DeviceID == "" || helper.client.UserID == "" { return fmt.Errorf("the client must be logged in") @@ -209,22 +177,16 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } - if syncer != nil { - syncer.OnSync(helper.mach.ProcessSyncResponse) - syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent) - if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok { - syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted) - } else { - helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.") - } - if helper.managedStateStore != nil { - syncer.OnEvent(helper.client.StateStoreSyncHandler) - } - } else if helper.ASEventProcessor != nil { - helper.mach.AddAppserviceListener(helper.ASEventProcessor) - helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted) + syncer.OnSync(helper.mach.ProcessSyncResponse) + syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent) + if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok { + syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted) + } else { + helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.") + } + if helper.managedStateStore != nil { + syncer.OnEvent(helper.client.StateStoreSyncHandler) } - return nil } @@ -261,24 +223,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error if !ok || len(device.Keys) == 0 { if isShared { return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server") + } else { + helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine") + return nil } - helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys") - err = helper.mach.ShareKeys(ctx, -1) - if err != nil { - return fmt.Errorf("failed to share keys: %w", err) - } - return nil } else if !isShared { return fmt.Errorf("olm account is not marked as shared, but there are keys on the server") } else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed { return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed) + } + if !isShared { + helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?") } else { helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine") - return nil } + return nil } -var NoSessionFound = crypto.ErrNoSessionFound +var NoSessionFound = crypto.NoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second @@ -297,25 +259,29 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even ctx = log.WithContext(ctx) decrypted, err := helper.Decrypt(ctx, evt) - if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" { - go helper.waitForSession(ctx, evt) - } else if err != nil { + if errors.Is(err, NoSessionFound) { + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err = helper.Decrypt(ctx, evt) + } else { + go helper.waitLongerForSession(ctx, log, evt) + return + } + } + if err != nil { log.Warn().Err(err).Msg("Failed to decrypt event") helper.DecryptErrorCallback(evt, err) - } else { - helper.postDecrypt(ctx, decrypted) + return } + helper.postDecrypt(ctx, decrypted) } func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { decrypted.Mautrix.EventSource |= event.SourceDecrypted - if helper.CustomPostDecrypt != nil { - helper.CustomPostDecrypt(ctx, decrypted) - } else if helper.ASEventProcessor != nil { - helper.ASEventProcessor.Dispatch(ctx, decrypted) - } else { - helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) - } + helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) } func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { @@ -345,33 +311,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) { - log := zerolog.Ctx(ctx) - content := evt.Content.AsEncrypted() - - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err := helper.Decrypt(ctx, evt) - if err != nil { - log.Warn().Err(err).Msg("Failed to decrypt event") - helper.DecryptErrorCallback(evt, err) - } else { - helper.postDecrypt(ctx, decrypted) - } - } else { - go helper.waitLongerForSession(ctx, evt) - } -} - -func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) { - log := zerolog.Ctx(ctx) +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) { content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") - //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -419,7 +362,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R defer helper.lock.RUnlock() encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { - if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) { + if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { return } helper.log.Debug(). diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 457d5a0c..00f99ce4 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -24,23 +24,13 @@ import ( ) var ( - ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") - ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") - ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") - ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - ErrRatchetError = errors.New("failed to ratchet session after use") - ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload") -) - -// Deprecated: use variables prefixed with Err -var ( - IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType - NoSessionFound = ErrNoSessionFound - DuplicateMessageIndex = ErrDuplicateMessageIndex - WrongRoom = ErrWrongRoom - DeviceKeyMismatch = ErrDeviceKeyMismatch - RatchetError = ErrRatchetError + IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") + NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") + DuplicateMessageIndex = errors.New("duplicate megolm message index") + WrongRoom = errors.New("encrypted megolm event is not intended for this room") + DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") + SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") + RatchetError = errors.New("failed to ratchet session after use") ) type megolmEvent struct { @@ -55,30 +45,13 @@ var ( relatesToTopLevelPath = exgjson.Path("content", "m.relates_to") ) -const sessionIDLength = 43 - -func validateCiphertextCharacters(ciphertext []byte) bool { - for _, b := range ciphertext { - if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' { - return false - } - } - return true -} - // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, ErrIncorrectEncryptedContentType + return nil, IncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { - return nil, ErrUnsupportedAlgorithm - } else if len(content.MegolmCiphertext) < 74 { - return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext)) - } else if len(content.SessionID) != sessionIDLength { - return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID)) - } else if !validateCiphertextCharacters(content.MegolmCiphertext) { - return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload) + return nil, UnsupportedAlgorithm } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). @@ -124,13 +97,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - log.Debug(). - Stringer("session_sender_key", sess.SenderKey). - Stringer("device_sender_key", device.IdentityKey). - Stringer("session_signing_key", sess.SigningKey). - Stringer("device_signing_key", device.SigningKey). - Msg("Device keys don't match keys in session, marking as untrusted") - trustLevel = id.TrustStateDeviceKeyMismatch + return nil, DeviceKeyMismatch } else { trustLevel, err = mach.ResolveTrustContext(ctx, device) if err != nil { @@ -180,9 +147,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != encryptionRoomID { - return nil, ErrWrongRoom + return nil, WrongRoom } - if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { + if evt.StateKey != nil && megolmEvt.StateKey != nil { megolmEvt.Type.Class = event.StateEventType } else { megolmEvt.Type.Class = evt.Type.Class @@ -193,7 +160,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { if errors.Is(err, event.ErrUnsupportedContentType) { log.Warn().Msg("Unsupported event type in encrypted event") - } else if !mach.IgnorePostDecryptionParseErrors { + } else { return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err) } } @@ -213,7 +180,6 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event TrustSource: device, ForwardedKeys: forwardedKeys, WasEncrypted: true, - EventSource: evt.Mautrix.EventSource | event.SourceDecrypted, ReceivedAt: evt.Mautrix.ReceivedAt, }, }, nil @@ -235,19 +201,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") - return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex) + return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) } firstKnown := sess.Internal.FirstKnownIndex() log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") - return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) } else if !ok { log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate") - return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown) } log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate") - return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) } func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { @@ -258,11 +224,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { - return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) + return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) + } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { + return sess, nil, 0, SenderKeyMismatch } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt { + if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err) } @@ -270,7 +238,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) } else if !ok { - return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex) + return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) } // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function @@ -322,24 +290,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else { log.Info().Msg("Deleted fully used session") } } else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else { log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)") } diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index aea5e6dc..55614b76 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -8,45 +8,26 @@ package crypto import ( "context" - "crypto/sha256" - "encoding/base64" "encoding/json" "errors" "fmt" - "slices" "time" "github.com/rs/zerolog" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/ptr" - "maunium.net/go/mautrix/crypto/goolm/account" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( - ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") - ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") - ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type") - ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") - ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") - ErrSenderMismatch = errors.New("mismatched sender in olm payload") - ErrRecipientMismatch = errors.New("mismatched recipient in olm payload") - ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") - ErrDuplicateMessage = errors.New("duplicate olm message") -) - -// Deprecated: use variables prefixed with Err -var ( - UnsupportedAlgorithm = ErrUnsupportedAlgorithm - NotEncryptedForMe = ErrNotEncryptedForMe - UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType - DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession - DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage - SenderMismatch = ErrSenderMismatch - RecipientMismatch = ErrRecipientMismatch - RecipientKeyMismatch = ErrRecipientKeyMismatch + UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") + NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") + UnsupportedOlmMessageType = errors.New("unsupported olm message type") + DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") + DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") + SenderMismatch = errors.New("mismatched sender in olm payload") + RecipientMismatch = errors.New("mismatched recipient in olm payload") + RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -68,13 +49,13 @@ type DecryptedOlmEvent struct { func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, ErrIncorrectEncryptedContentType + return nil, IncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { - return nil, ErrUnsupportedAlgorithm + return nil, UnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { - return nil, ErrNotEncryptedForMe + return nil, NotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { @@ -90,14 +71,9 @@ type OlmEventKeys struct { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { - return nil, ErrUnsupportedOlmMessageType + return nil, UnsupportedOlmMessageType } - log := mach.machOrContextLog(ctx).With(). - Stringer("sender_key", senderKey). - Int("olm_msg_type", int(olmType)). - Logger() - ctx = log.WithContext(ctx) endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second) plaintext, err := mach.tryDecryptOlmCiphertext(ctx, evt.Sender, senderKey, olmType, ciphertext) endTimeTrace() @@ -114,18 +90,16 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } olmEvt.Type.Class = evt.Type.Class if evt.Sender != olmEvt.Sender { - return nil, ErrSenderMismatch + return nil, SenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { - return nil, ErrRecipientMismatch + return nil, RecipientMismatch } else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 { - return nil, ErrRecipientKeyMismatch + return nil, RecipientKeyMismatch } - if len(olmEvt.Content.VeryRaw) > 0 { - err = olmEvt.Content.ParseRaw(olmEvt.Type) - if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { - return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) - } + err = olmEvt.Content.ParseRaw(olmEvt.Type) + if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { + return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) } olmEvt.SenderKey = senderKey @@ -133,40 +107,16 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e return &olmEvt, nil } -func olmMessageHash(ciphertext string) ([32]byte, error) { - if ciphertext == "" { - return [32]byte{}, fmt.Errorf("empty ciphertext") - } - ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext) - return sha256.Sum256(ciphertextBytes), err -} - func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { - ciphertextHash, err := olmMessageHash(ciphertext) - if err != nil { - return nil, fmt.Errorf("failed to hash olm ciphertext: %w", err) - } - log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second) mach.olmLock.Lock() endTimeTrace() defer mach.olmLock.Unlock() - duplicateTS, err := mach.CryptoStore.GetOlmHash(ctx, ciphertextHash) + plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext) if err != nil { - log.Warn().Err(err).Msg("Failed to check for duplicate olm message") - } else if !duplicateTS.IsZero() { - log.Warn(). - Hex("ciphertext_hash", ciphertextHash[:]). - Time("duplicate_ts", duplicateTS). - Msg("Ignoring duplicate olm message") - return nil, ErrDuplicateMessage - } - - plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) - if err != nil { - if err == ErrDecryptionFailedWithMatchingSession { + if err == DecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") go mach.unwedgeDevice(log, sender, senderKey) } @@ -184,10 +134,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U // if it isn't one at this point in time anymore, so return early. if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(log, sender, senderKey) - return nil, ErrDecryptionFailedForNormalMessage + return nil, DecryptionFailedForNormalMessage } - accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -198,8 +147,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U } log = log.With().Str("new_olm_session_id", session.ID().String()).Logger() log.Debug(). - Hex("ciphertext_hash", ciphertextHash[:]). - Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("olm_session_description", session.Describe()). Msg("Created inbound olm session") ctx = log.WithContext(ctx) @@ -208,28 +155,11 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err = session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { - log.Debug(). - Hex("ciphertext_hash", ciphertextHash[:]). - Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). - Str("ciphertext", ciphertext). - Str("olm_session_description", session.Describe()). - Msg("DEBUG: Failed to decrypt prekey olm message with newly created session") - err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup) - if err2 != nil { - log.Debug().Err(err2).Msg("Goolm confirmed decryption failure") - } else { - log.Warn().Msg("Goolm decryption was successful after libolm failure?") - } - go mach.unwedgeDevice(log, sender, senderKey) return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second) - err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) - if err != nil { - log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") - } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { @@ -238,28 +168,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U return plaintext, nil } -func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error { - acc, err := account.AccountFromPickled(accountBackup, []byte("tmp")) - if err != nil { - return fmt.Errorf("failed to unpickle olm account: %w", err) - } - sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext) - if err != nil { - return fmt.Errorf("failed to create inbound session: %w", err) - } - _, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey) - if err != nil { - // This is the expected result if libolm failed - return fmt.Errorf("failed to decrypt with new session: %w", err) - } - return nil -} - -const MaxOlmSessionsPerDevice = 5 - -func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( - ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, ciphertextHash [32]byte, -) ([]byte, error) { +func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second) sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey) @@ -267,32 +176,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( if err != nil { return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err) } - if len(sessions) > MaxOlmSessionsPerDevice*2 { - // SQL store sorts sessions, but other implementations may not, so re-sort just in case - slices.SortFunc(sessions, func(a, b *OlmSession) int { - return b.LastDecryptedTime.Compare(a.LastDecryptedTime) - }) - log.Warn(). - Int("session_count", len(sessions)). - Time("newest_last_decrypted_at", sessions[0].LastDecryptedTime). - Time("oldest_last_decrypted_at", sessions[len(sessions)-1].LastDecryptedTime). - Msg("Too many sessions, deleting old ones") - for i := MaxOlmSessionsPerDevice; i < len(sessions); i++ { - err = mach.CryptoStore.DeleteSession(ctx, senderKey, sessions[i]) - if err != nil { - log.Warn().Err(err). - Stringer("olm_session_id", sessions[i].ID()). - Time("last_decrypt", sessions[i].LastDecryptedTime). - Msg("Failed to delete olm session") - } else { - log.Debug(). - Stringer("olm_session_id", sessions[i].ID()). - Time("last_decrypt", sessions[i].LastDecryptedTime). - Msg("Deleted olm session") - } - } - sessions = sessions[:MaxOlmSessionsPerDevice] - } for _, session := range sessions { log := log.With().Str("olm_session_id", session.ID().String()).Logger() @@ -307,33 +190,22 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( continue } } + log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message") endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second) plaintext, err := session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { - log.Warn().Err(err). - Hex("ciphertext_hash", ciphertextHash[:]). - Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). - Str("session_description", session.Describe()). - Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { - return nil, ErrDecryptionFailedWithMatchingSession + return nil, DecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) - err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) - if err != nil { - log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") - } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting") } - log.Debug(). - Hex("ciphertext_hash", ciphertextHash[:]). - Str("session_description", session.Describe()). - Msg("Decrypted olm message") + log.Debug().Msg("Decrypted olm message") return plaintext, nil } } @@ -357,10 +229,10 @@ const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) { log = log.With().Str("action", "unwedge olm session").Logger() - ctx := log.WithContext(mach.backgroundCtx) + ctx := log.WithContext(context.TODO()) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] - delta := time.Since(prevUnwedge) + delta := time.Now().Sub(prevUnwedge) if ok && delta < MinUnwedgeInterval { log.Debug(). Str("previous_recreation", delta.String()). @@ -371,17 +243,6 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send mach.recentlyUnwedged[senderKey] = time.Now() mach.recentlyUnwedgedLock.Unlock() - lastCreatedAt, err := mach.CryptoStore.GetNewestSessionCreationTS(ctx, senderKey) - if err != nil { - log.Warn().Err(err).Msg("Failed to get newest session creation timestamp") - return - } else if time.Since(lastCreatedAt) < MinUnwedgeInterval { - log.Debug(). - Time("last_created_at", lastCreatedAt). - Msg("Not creating new Olm session as it was already recreated recently") - return - } - deviceIdentity, err := mach.GetOrFetchDeviceByKey(ctx, sender, senderKey) if err != nil { log.Error().Err(err).Msg("Failed to find device info by identity key") @@ -391,10 +252,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send return } - log.Debug(). - Time("last_created", lastCreatedAt). - Stringer("device_id", deviceIdentity.DeviceID). - Msg("Creating new Olm session") + log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session") mach.devicesToUnwedgeLock.Lock() mach.devicesToUnwedge[senderKey] = true mach.devicesToUnwedgeLock.Unlock() diff --git a/crypto/devicelist.go b/crypto/devicelist.go index f0d2b129..de6c21f3 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -10,8 +10,6 @@ import ( "context" "errors" "fmt" - "slices" - "strings" "github.com/rs/zerolog" "go.mau.fi/util/exzerolog" @@ -22,23 +20,12 @@ import ( ) var ( - ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") - ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object") - ErrMismatchingSigningKey = errors.New("received update for device with different signing key") - ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key") - ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key") - ErrInvalidKeySignature = errors.New("invalid signature on device keys") - ErrUserNotTracked = errors.New("user is not tracked") -) - -// Deprecated: use variables prefixed with Err -var ( - MismatchingDeviceID = ErrMismatchingDeviceID - MismatchingUserID = ErrMismatchingUserID - MismatchingSigningKey = ErrMismatchingSigningKey - NoSigningKeyFound = ErrNoSigningKeyFound - NoIdentityKeyFound = ErrNoIdentityKeyFound - InvalidKeySignature = ErrInvalidKeySignature + MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") + MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") + MismatchingSigningKey = errors.New("received update for device with different signing key") + NoSigningKeyFound = errors.New("didn't find ed25519 signing key") + NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") + InvalidKeySignature = errors.New("invalid signature on device keys") ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -53,81 +40,6 @@ func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys m return nil } -type CachedDevices struct { - Devices []*id.Device - MasterKey *id.CrossSigningKey - HasValidSelfSigningKey bool - MasterKeySignedByUs bool -} - -func (mach *OlmMachine) GetCachedDevices(ctx context.Context, userID id.UserID) (*CachedDevices, error) { - userIDs, err := mach.CryptoStore.FilterTrackedUsers(ctx, []id.UserID{userID}) - if err != nil { - return nil, fmt.Errorf("failed to check if user's devices are tracked: %w", err) - } else if len(userIDs) == 0 { - return nil, ErrUserNotTracked - } - ownKeys := mach.GetOwnCrossSigningPublicKeys(ctx) - var ownUserSigningKey id.Ed25519 - if ownKeys != nil { - ownUserSigningKey = ownKeys.UserSigningKey - } - var resp CachedDevices - csKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) - theirMasterKey := csKeys[id.XSUsageMaster] - theirSelfSignKey := csKeys[id.XSUsageSelfSigning] - if err != nil { - return nil, fmt.Errorf("failed to get cross-signing keys: %w", err) - } else if csKeys != nil && theirMasterKey.Key != "" { - resp.MasterKey = &theirMasterKey - if theirSelfSignKey.Key != "" { - resp.HasValidSelfSigningKey, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirSelfSignKey.Key, userID, theirMasterKey.Key) - if err != nil { - return nil, fmt.Errorf("failed to check if self-signing key is signed by master key: %w", err) - } - } - } - devices, err := mach.CryptoStore.GetDevices(ctx, userID) - if err != nil { - return nil, fmt.Errorf("failed to get devices: %w", err) - } - if userID == mach.Client.UserID { - if ownKeys != nil && ownKeys.MasterKey == theirMasterKey.Key { - resp.MasterKeySignedByUs, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMasterKey.Key, userID, mach.OwnIdentity().SigningKey) - } - } else if ownUserSigningKey != "" && theirMasterKey.Key != "" { - // TODO should own master key and user-signing key signatures be checked here too? - resp.MasterKeySignedByUs, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMasterKey.Key, mach.Client.UserID, ownUserSigningKey) - } - if err != nil { - return nil, fmt.Errorf("failed to check if user is trusted: %w", err) - } - resp.Devices = make([]*id.Device, len(devices)) - i := 0 - for _, device := range devices { - resp.Devices[i] = device - if resp.HasValidSelfSigningKey && device.Trust == id.TrustStateUnset { - signed, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSelfSignKey.Key) - if err != nil { - return nil, fmt.Errorf("failed to check if device %s is signed by self-signing key: %w", device.DeviceID, err) - } else if signed { - if resp.MasterKeySignedByUs { - device.Trust = id.TrustStateCrossSignedVerified - } else if theirMasterKey.Key == theirMasterKey.First { - device.Trust = id.TrustStateCrossSignedTOFU - } else { - device.Trust = id.TrustStateCrossSignedUntrusted - } - } - } - i++ - } - slices.SortFunc(resp.Devices, func(a, b *id.Device) int { - return strings.Compare(a.DeviceID.String(), b.DeviceID.String()) - }) - return &resp, nil -} - func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) { log := zerolog.Ctx(ctx) deviceKeys := resp.DeviceKeys[userID][deviceID] @@ -215,7 +127,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received") data = make(map[id.UserID]map[id.DeviceID]*id.Device) for userID, devices := range resp.DeviceKeys { - log := log.With().Stringer("user_id", userID).Logger() + log := log.With().Str("user_id", userID.String()).Logger() delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*id.Device) @@ -231,7 +143,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ Msg("Updating devices in store") changed := false for deviceID, deviceKeys := range devices { - log := log.With().Stringer("device_id", deviceID).Logger() + log := log.With().Str("device_id", deviceID.String()).Logger() existing, ok := existingDevices[deviceID] if !ok { // New device @@ -279,7 +191,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ } } for userID := range req.DeviceKeys { - log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user") + log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user") } mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys) @@ -321,28 +233,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) { if deviceID != deviceKeys.DeviceID { - return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID) + return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID) } else if userID != deviceKeys.UserID { - return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID) + return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID) } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { - return nil, ErrNoSigningKeyFound + return nil, NoSigningKeyFound } else if identityKey == "" { - return nil, ErrNoIdentityKeyFound + return nil, NoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { - return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey) + return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) } ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { - return existing, ErrInvalidKeySignature + return existing, InvalidKeySignature } name, ok := deviceKeys.Unsigned["device_display_name"].(string) diff --git a/crypto/ed25519/ed25519.go b/crypto/ed25519/ed25519.go deleted file mode 100644 index 327cbb3c..00000000 --- a/crypto/ed25519/ed25519.go +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright 2024 Sumner Evans. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package ed25519 implements the Ed25519 signature algorithm. See -// https://ed25519.cr.yp.to/. -// -// This package stores the private key in the NaCl format, which is a different -// format than that used by the [crypto/ed25519] package in the standard -// library. -// -// This picture will help with the rest of the explanation: -// https://blog.mozilla.org/warner/files/2011/11/key-formats.png -// -// The private key in the [crypto/ed25519] package is a 64-byte value where the -// first 32-bytes are the seed and the last 32-bytes are the public key. -// -// The private key in this package is stored in the NaCl format. That is, the -// left 32-bytes are the private scalar A and the right 32-bytes are the right -// half of the SHA512 result. -// -// The contents of this package are mostly copied from the standard library, -// and as such the source code is licensed under the BSD license of the -// standard library implementation. -// -// Other notable changes from the standard library include: -// -// - The Seed function of the standard library is not implemented in this -// package because there is no way to recover the seed after hashing it. -package ed25519 - -import ( - "crypto" - "crypto/ed25519" - cryptorand "crypto/rand" - "crypto/sha512" - "crypto/subtle" - "errors" - "io" - "strconv" - - "filippo.io/edwards25519" -) - -const ( - // PublicKeySize is the size, in bytes, of public keys as used in this package. - PublicKeySize = 32 - // PrivateKeySize is the size, in bytes, of private keys as used in this package. - PrivateKeySize = 64 - // SignatureSize is the size, in bytes, of signatures generated and verified by this package. - SignatureSize = 64 - // SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032. - SeedSize = 32 -) - -// PublicKey is the type of Ed25519 public keys. -type PublicKey []byte - -// Any methods implemented on PublicKey might need to also be implemented on -// PrivateKey, as the latter embeds the former and will expose its methods. - -// Equal reports whether pub and x have the same value. -func (pub PublicKey) Equal(x crypto.PublicKey) bool { - switch x := x.(type) { - case PublicKey: - return subtle.ConstantTimeCompare(pub, x) == 1 - case ed25519.PublicKey: - return subtle.ConstantTimeCompare(pub, x) == 1 - default: - return false - } -} - -// PrivateKey is the type of Ed25519 private keys. It implements [crypto.Signer]. -type PrivateKey []byte - -// Public returns the [PublicKey] corresponding to priv. -// -// This method differs from the standard library because it calculates the -// public key instead of returning the right half of the private key (which -// contains the public key in the standard library). -func (priv PrivateKey) Public() crypto.PublicKey { - s, err := edwards25519.NewScalar().SetBytesWithClamping(priv[:32]) - if err != nil { - panic("ed25519: internal error: setting scalar failed") - } - return (&edwards25519.Point{}).ScalarBaseMult(s).Bytes() -} - -// Equal reports whether priv and x have the same value. -func (priv PrivateKey) Equal(x crypto.PrivateKey) bool { - // TODO do we have any need to check equality with standard library ed25519 - // private keys? - xx, ok := x.(PrivateKey) - if !ok { - return false - } - return subtle.ConstantTimeCompare(priv, xx) == 1 -} - -// Sign signs the given message with priv. rand is ignored and can be nil. -// -// If opts.HashFunc() is [crypto.SHA512], the pre-hashed variant Ed25519ph is used -// and message is expected to be a SHA-512 hash, otherwise opts.HashFunc() must -// be [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two -// passes over messages to be signed. -// -// A value of type [Options] can be used as opts, or crypto.Hash(0) or -// crypto.SHA512 directly to select plain Ed25519 or Ed25519ph, respectively. -func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) { - hash := opts.HashFunc() - context := "" - if opts, ok := opts.(*Options); ok { - context = opts.Context - } - switch { - case hash == crypto.SHA512: // Ed25519ph - if l := len(message); l != sha512.Size { - return nil, errors.New("ed25519: bad Ed25519ph message hash length: " + strconv.Itoa(l)) - } - if l := len(context); l > 255 { - return nil, errors.New("ed25519: bad Ed25519ph context length: " + strconv.Itoa(l)) - } - signature := make([]byte, SignatureSize) - sign(signature, priv, message, domPrefixPh, context) - return signature, nil - case hash == crypto.Hash(0) && context != "": // Ed25519ctx - if l := len(context); l > 255 { - return nil, errors.New("ed25519: bad Ed25519ctx context length: " + strconv.Itoa(l)) - } - signature := make([]byte, SignatureSize) - sign(signature, priv, message, domPrefixCtx, context) - return signature, nil - case hash == crypto.Hash(0): // Ed25519 - return Sign(priv, message), nil - default: - return nil, errors.New("ed25519: expected opts.HashFunc() zero (unhashed message, for standard Ed25519) or SHA-512 (for Ed25519ph)") - } -} - -// Options can be used with [PrivateKey.Sign] or [VerifyWithOptions] -// to select Ed25519 variants. -type Options struct { - // Hash can be zero for regular Ed25519, or crypto.SHA512 for Ed25519ph. - Hash crypto.Hash - - // Context, if not empty, selects Ed25519ctx or provides the context string - // for Ed25519ph. It can be at most 255 bytes in length. - Context string -} - -// HashFunc returns o.Hash. -func (o *Options) HashFunc() crypto.Hash { return o.Hash } - -// GenerateKey generates a public/private key pair using entropy from rand. -// If rand is nil, [crypto/rand.Reader] will be used. -// -// The output of this function is deterministic, and equivalent to reading -// [SeedSize] bytes from rand, and passing them to [NewKeyFromSeed]. -func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) { - if rand == nil { - rand = cryptorand.Reader - } - - seed := make([]byte, SeedSize) - if _, err := io.ReadFull(rand, seed); err != nil { - return nil, nil, err - } - - privateKey := NewKeyFromSeed(seed) - return PublicKey(privateKey.Public().([]byte)), privateKey, nil -} - -// NewKeyFromSeed calculates a private key from a seed. It will panic if -// len(seed) is not [SeedSize]. This function is provided for interoperability -// with RFC 8032. RFC 8032's private keys correspond to seeds in this -// package. -func NewKeyFromSeed(seed []byte) PrivateKey { - // Outline the function body so that the returned key can be stack-allocated. - privateKey := make([]byte, PrivateKeySize) - newKeyFromSeed(privateKey, seed) - return privateKey -} - -func newKeyFromSeed(privateKey, seed []byte) { - if l := len(seed); l != SeedSize { - panic("ed25519: bad seed length: " + strconv.Itoa(l)) - } - - h := sha512.Sum512(seed) - - // Apply clamping to get A in the left half, and leave the right half - // as-is. This gets the private key into the NaCl format. - h[0] &= 248 - h[31] &= 63 - h[31] |= 64 - copy(privateKey, h[:]) -} - -// Sign signs the message with privateKey and returns a signature. It will -// panic if len(privateKey) is not [PrivateKeySize]. -func Sign(privateKey PrivateKey, message []byte) []byte { - // Outline the function body so that the returned signature can be - // stack-allocated. - signature := make([]byte, SignatureSize) - sign(signature, privateKey, message, domPrefixPure, "") - return signature -} - -// Domain separation prefixes used to disambiguate Ed25519/Ed25519ph/Ed25519ctx. -// See RFC 8032, Section 2 and Section 5.1. -const ( - // domPrefixPure is empty for pure Ed25519. - domPrefixPure = "" - // domPrefixPh is dom2(phflag=1) for Ed25519ph. It must be followed by the - // uint8-length prefixed context. - domPrefixPh = "SigEd25519 no Ed25519 collisions\x01" - // domPrefixCtx is dom2(phflag=0) for Ed25519ctx. It must be followed by the - // uint8-length prefixed context. - domPrefixCtx = "SigEd25519 no Ed25519 collisions\x00" -) - -func sign(signature []byte, privateKey PrivateKey, message []byte, domPrefix, context string) { - if l := len(privateKey); l != PrivateKeySize { - panic("ed25519: bad private key length: " + strconv.Itoa(l)) - } - // We have to extract the public key from the private key. - publicKey := privateKey.Public().([]byte) - // The private key is already the hashed value of the seed. - h := privateKey - - s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) - if err != nil { - panic("ed25519: internal error: setting scalar failed") - } - prefix := h[32:] - - mh := sha512.New() - if domPrefix != domPrefixPure { - mh.Write([]byte(domPrefix)) - mh.Write([]byte{byte(len(context))}) - mh.Write([]byte(context)) - } - mh.Write(prefix) - mh.Write(message) - messageDigest := make([]byte, 0, sha512.Size) - messageDigest = mh.Sum(messageDigest) - r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) - if err != nil { - panic("ed25519: internal error: setting scalar failed") - } - - R := (&edwards25519.Point{}).ScalarBaseMult(r) - - kh := sha512.New() - if domPrefix != domPrefixPure { - kh.Write([]byte(domPrefix)) - kh.Write([]byte{byte(len(context))}) - kh.Write([]byte(context)) - } - kh.Write(R.Bytes()) - kh.Write(publicKey) - kh.Write(message) - hramDigest := make([]byte, 0, sha512.Size) - hramDigest = kh.Sum(hramDigest) - k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) - if err != nil { - panic("ed25519: internal error: setting scalar failed") - } - - S := edwards25519.NewScalar().MultiplyAdd(k, s, r) - - copy(signature[:32], R.Bytes()) - copy(signature[32:], S.Bytes()) -} - -// Verify reports whether sig is a valid signature of message by publicKey. It -// will panic if len(publicKey) is not [PublicKeySize]. -// -// This is just a wrapper around [ed25519.Verify] from the standard library. -func Verify(publicKey PublicKey, message, sig []byte) bool { - return ed25519.Verify(ed25519.PublicKey(publicKey), message, sig) -} - -// VerifyWithOptions reports whether sig is a valid signature of message by -// publicKey. A valid signature is indicated by returning a nil error. It will -// panic if len(publicKey) is not [PublicKeySize]. -// -// If opts.Hash is [crypto.SHA512], the pre-hashed variant Ed25519ph is used and -// message is expected to be a SHA-512 hash, otherwise opts.Hash must be -// [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two -// passes over messages to be signed. -// -// This is just a wrapper around [ed25519.VerifyWithOptions] from the standard -// library. -func VerifyWithOptions(publicKey PublicKey, message, sig []byte, opts *Options) error { - return ed25519.VerifyWithOptions(ed25519.PublicKey(publicKey), message, sig, &ed25519.Options{ - Hash: opts.Hash, - Context: opts.Context, - }) -} diff --git a/crypto/ed25519/ed25519_test.go b/crypto/ed25519/ed25519_test.go deleted file mode 100644 index 931c06f6..00000000 --- a/crypto/ed25519/ed25519_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package ed25519_test - -import ( - stdlibed25519 "crypto/ed25519" - "testing" - - "github.com/stretchr/testify/assert" - "go.mau.fi/util/random" - - "maunium.net/go/mautrix/crypto/ed25519" -) - -func TestPubkeyEqual(t *testing.T) { - pubkeyBytes := random.Bytes(32) - pubkey := ed25519.PublicKey(pubkeyBytes) - pubkey2 := ed25519.PublicKey(pubkeyBytes) - stdlibPubkey := stdlibed25519.PublicKey(pubkeyBytes) - assert.True(t, pubkey.Equal(pubkey2)) - assert.True(t, pubkey.Equal(stdlibPubkey)) -} diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 88f9c8d4..93fe6409 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -15,8 +15,6 @@ import ( "fmt" "github.com/rs/zerolog" - "github.com/tidwall/gjson" - "go.mau.fi/util/exgjson" "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" @@ -25,32 +23,11 @@ import ( ) var ( - ErrNoGroupSession = errors.New("no group session created") + AlreadyShared = errors.New("group session already shared") + NoGroupSession = errors.New("no group session created") ) -// Deprecated: use variables prefixed with Err -var ( - NoGroupSession = ErrNoGroupSession -) - -func getRawJSON[T any](content json.RawMessage, path ...string) *T { - value := gjson.GetBytes(content, exgjson.Path(path...)) - if !value.IsObject() { - return nil - } - var result T - err := json.Unmarshal([]byte(value.Raw), &result) - if err != nil { - return nil - } - return &result -} - -func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo { - contentJSON, ok := content.(json.RawMessage) - if ok { - return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to") - } +func getRelatesTo(content interface{}) *event.RelatesTo { contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed @@ -59,14 +36,10 @@ func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo { if ok { return relatable.OptionalGetRelatesTo() } - return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to") + return nil } -func getMentions(content any) *event.Mentions { - contentJSON, ok := content.(json.RawMessage) - if ok { - return getRawJSON[event.Mentions](contentJSON, "m.mentions") - } +func getMentions(content interface{}) *event.Mentions { contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed @@ -87,20 +60,15 @@ type rawMegolmEvent struct { // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { - return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession + return err == SessionExpired || err == SessionNotShared || err == NoGroupSession } func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { - if len(ciphertext) == 0 { - return 0, fmt.Errorf("empty ciphertext") - } decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext))) var err error _, err = base64.RawStdEncoding.Decode(decoded, ciphertext) if err != nil { return 0, err - } else if len(decoded) < 2+binary.MaxVarintLen64 { - return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded)) } else if decoded[0] != 3 || decoded[1] != 8 { return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1]) } @@ -130,7 +98,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { - return nil, ErrNoGroupSession + return nil, NoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ RoomID: roomID, @@ -168,21 +136,12 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room Algorithm: id.AlgorithmMegolmV1, SessionID: session.ID(), MegolmCiphertext: ciphertext, - RelatesTo: getRelatesTo(content, plaintext), + RelatesTo: getRelatesTo(content), // These are deprecated SenderKey: mach.account.IdentityKey(), DeviceID: mach.Client.DeviceID, } - if mach.MSC4392Relations && encrypted.RelatesTo != nil { - // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content. - // Other relations like threads are still left unencrypted. - encrypted.RelatesTo.InReplyTo = nil - encrypted.RelatesTo.IsFallingBack = false - if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" { - encrypted.RelatesTo = nil - } - } if mach.PlaintextMentions { encrypted.Mentions = getMentions(content) } @@ -197,10 +156,7 @@ func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.R Msg("Failed to get encryption event in room") return nil, fmt.Errorf("failed to get encryption event in room %s: %w", roomID, err) } - session, err := NewOutboundGroupSession(roomID, encryptionEvent) - if err != nil { - return nil, err - } + session := NewOutboundGroupSession(roomID, encryptionEvent) if !mach.DontStoreOutboundKeys { signingKey, idKey := mach.account.Keys() err := mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) @@ -227,8 +183,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { - mach.machOrContextLog(ctx).Debug().Stringer("room_id", roomID).Msg("Not re-sharing group session, already shared") - return nil + return AlreadyShared } log := mach.machOrContextLog(ctx).With(). Str("room_id", roomID.String()). @@ -252,7 +207,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, var fetchKeysForUsers []id.UserID for _, userID := range users { - log := log.With().Stringer("target_user_id", userID).Logger() + log := log.With().Str("target_user_id", userID.String()).Logger() devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Err(err).Msg("Failed to get devices of user") @@ -324,7 +279,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, toDeviceWithheld.Messages[userID] = withheld } - log := log.With().Stringer("target_user_id", userID).Logger() + log := log.With().Str("target_user_id", userID.String()).Logger() log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)") mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil) log.Debug(). @@ -370,39 +325,41 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session log.Trace().Msg("Encrypting group session for all found devices") deviceCount := 0 toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} - logUsers := zerolog.Dict() for userID, sessions := range olmSessions { if len(sessions) == 0 { continue } - logDevices := zerolog.Dict() output := make(map[id.DeviceID]*event.Content) toDevice.Messages[userID] = output for deviceID, device := range sessions { + log.Trace(). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Msg("Encrypting group session for device") content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} - logDevices.Str(string(deviceID), string(device.identity.IdentityKey)) deviceCount++ + log.Debug(). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Msg("Encrypted group session for device") if !mach.DisableSharedGroupSessionTracking { err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) if err != nil { log.Warn(). Err(err). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.identity.IdentityKey). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). Stringer("target_session_id", session.id). Msg("Failed to mark outbound group session shared") } } } - logUsers.Dict(string(userID), logDevices) } log.Debug(). Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). - Dict("destination_map", logUsers). Msg("Sending to-device messages to share group session") _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err @@ -411,9 +368,8 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) { for deviceID, device := range devices { log := zerolog.Ctx(ctx).With(). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.IdentityKey). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). Logger() userKey := UserDevice{UserID: userID, DeviceID: deviceID} if state := session.Users[userKey]; state != OGSNotShared { @@ -431,7 +387,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out Reason: "Device is blacklisted", }} session.Users[userKey] = OGSIgnored - } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState < mach.SendKeysMinTrust { + } else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 765307af..15e9df29 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -17,70 +17,6 @@ import ( "maunium.net/go/mautrix/id" ) -func (mach *OlmMachine) EncryptToDevices(ctx context.Context, eventType event.Type, req *mautrix.ReqSendToDevice) (*mautrix.ReqSendToDevice, error) { - devicesToCreateSessions := make(map[id.UserID]map[id.DeviceID]*id.Device) - for userID, devices := range req.Messages { - for deviceID := range devices { - device, err := mach.GetOrFetchDevice(ctx, userID, deviceID) - if err != nil { - return nil, fmt.Errorf("failed to get device %s of user %s: %w", deviceID, userID, err) - } - - if _, ok := devicesToCreateSessions[userID]; !ok { - devicesToCreateSessions[userID] = make(map[id.DeviceID]*id.Device) - } - devicesToCreateSessions[userID][deviceID] = device - } - } - if err := mach.createOutboundSessions(ctx, devicesToCreateSessions); err != nil { - return nil, fmt.Errorf("failed to create outbound sessions: %w", err) - } - - mach.olmLock.Lock() - defer mach.olmLock.Unlock() - - encryptedReq := &mautrix.ReqSendToDevice{ - Messages: make(map[id.UserID]map[id.DeviceID]*event.Content), - } - - log := mach.machOrContextLog(ctx) - - for userID, devices := range req.Messages { - encryptedReq.Messages[userID] = make(map[id.DeviceID]*event.Content) - - for deviceID, content := range devices { - device := devicesToCreateSessions[userID][deviceID] - - olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey) - if err != nil { - return nil, fmt.Errorf("failed to get latest session for device %s of %s: %w", deviceID, userID, err) - } else if olmSess == nil { - log.Warn(). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). - Str("identity_key", device.IdentityKey.String()). - Msg("No outbound session found for device") - continue - } - - encrypted := mach.encryptOlmEvent(ctx, olmSess, device, eventType, *content) - encryptedContent := &event.Content{Parsed: &encrypted} - - log.Debug(). - Str("decrypted_type", eventType.Type). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). - Str("target_identity_key", device.IdentityKey.String()). - Str("olm_session_id", olmSess.ID().String()). - Msg("Encrypted to-device event") - - encryptedReq.Messages[userID][deviceID] = encryptedContent - } - } - - return encryptedReq, nil -} - func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent { evt := &DecryptedOlmEvent{ Sender: mach.Client.UserID, @@ -96,19 +32,12 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession panic(err) } log := mach.machOrContextLog(ctx) - msgType, ciphertext, err := session.Encrypt(plaintext) - if err != nil { - panic(err) - } - ciphertextStr := string(ciphertext) - ciphertextHash, _ := olmMessageHash(ciphertextStr) log.Debug(). - Stringer("event_type", evtType). Str("recipient_identity_key", recipient.IdentityKey.String()). Str("olm_session_id", session.ID().String()). Str("olm_session_description", session.Describe()). - Hex("ciphertext_hash", ciphertextHash[:]). - Msg("Encrypted olm message") + Msg("Encrypting olm message") + msgType, ciphertext := session.Encrypt(plaintext) err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") @@ -119,7 +48,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession OlmCiphertext: event.OlmCiphertexts{ recipient.IdentityKey: { Type: msgType, - Body: ciphertextStr, + Body: string(ciphertext), }, }, } diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index b48843a4..4057543a 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -4,14 +4,18 @@ package account import ( "encoding/base64" "encoding/json" + "errors" "fmt" + "io" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -35,36 +39,41 @@ type Account struct { NumFallbackKeys uint8 `json:"number_fallback_keys"` } -// Ensure that Account adheres to the olm.Account interface. -var _ olm.Account = (*Account)(nil) - // AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) } a := &Account{} - return a, a.UnpickleAsJSON(pickled, key) + err := a.UnpickleAsJSON(pickled, key) + if err != nil { + return nil, err + } + return a, nil } // AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) } a := &Account{} - return a, a.Unpickle(pickled, key) + err := a.Unpickle(pickled, key) + if err != nil { + return nil, err + } + return a, nil } -// NewAccount creates a new Account. -func NewAccount() (*Account, error) { +// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation. +func NewAccount(reader io.Reader) (*Account, error) { a := &Account{} - kPEd25519, err := crypto.Ed25519GenerateKey() + kPEd25519, err := crypto.Ed25519GenerateKey(reader) if err != nil { return nil, err } a.IdKeys.Ed25519 = kPEd25519 - kPCurve25519, err := crypto.Curve25519GenerateKey() + kPCurve25519, err := crypto.Curve25519GenerateKey(reader) if err != nil { return nil, err } @@ -73,60 +82,72 @@ func NewAccount() (*Account, error) { } // PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (a *Account) PickleAsJSON(key []byte) ([]byte, error) { - return libolmpickle.PickleAsJSON(a, accountPickleVersionJSON, key) +func (a Account) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(a, accountPickleVersionJSON, key) } // UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Account) UnpickleAsJSON(pickled, key []byte) error { - return libolmpickle.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) + return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) } // IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string. -func (a *Account) IdentityKeysJSON() ([]byte, error) { +func (a Account) IdentityKeysJSON() ([]byte, error) { res := struct { Ed25519 string `json:"ed25519"` Curve25519 string `json:"curve25519"` }{} - ed25519, curve25519, err := a.IdentityKeys() - if err != nil { - return nil, err - } + ed25519, curve25519 := a.IdentityKeys() res.Ed25519 = string(ed25519) res.Curve25519 = string(curve25519) return json.Marshal(res) } // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account. -func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { +func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) { ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey)) curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey)) - return ed25519, curve25519, nil + return ed25519, curve25519 } // Sign returns the base64-encoded signature of a message using the Ed25519 key // for this Account. -func (a *Account) Sign(message []byte) ([]byte, error) { +func (a Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - return nil, fmt.Errorf("sign: %w", olm.ErrEmptyInput) - } else if signature, err := a.IdKeys.Ed25519.Sign(message); err != nil { - return nil, err - } else { - return []byte(base64.RawStdEncoding.EncodeToString(signature)), nil + return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput) } + return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil } // OneTimeKeys returns the public parts of the unpublished one time keys of the Account. // // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { +func (a Account) OneTimeKeys() map[string]id.Curve25519 { oneTimeKeys := make(map[string]id.Curve25519) for _, curKey := range a.OTKeys { if !curKey.Published { - oneTimeKeys[curKey.KeyIDEncoded()] = curKey.Key.PublicKey.B64Encoded() + oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded()) } } - return oneTimeKeys, nil + return oneTimeKeys +} + +//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string. +// +//The returned JSON is of format: +/* + { + Curve25519: { + "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", + "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" + } + } +*/ +func (a Account) OneTimeKeysJSON() ([]byte, error) { + res := make(map[string]map[string]id.Curve25519) + otKeys := a.OneTimeKeys() + res["Curve25519"] = otKeys + return json.Marshal(res) } // MarkKeysAsPublished marks the current set of one time keys and the fallback key as being @@ -142,14 +163,14 @@ func (a *Account) MarkKeysAsPublished() { // GenOneTimeKeys generates a number of new one time keys. If the total number // of keys stored by this Account exceeds MaxOneTimeKeys then the older -// keys are discarded. -func (a *Account) GenOneTimeKeys(num uint) error { +// keys are discarded. If reader is nil, crypto/rand is used for the key creation. +func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { for i := uint(0); i < num; i++ { key := crypto.OneTimeKey{ Published: false, ID: a.NextOneTimeKeyID, } - newKP, err := crypto.Curve25519GenerateKey() + newKP, err := crypto.Curve25519GenerateKey(reader) if err != nil { return err } @@ -165,9 +186,9 @@ func (a *Account) GenOneTimeKeys(num uint) error { // NewOutboundSession creates a new outbound session to a // given curve25519 identity Key and one time key. -func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { +func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, fmt.Errorf("outbound session: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput) } theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey)) if err != nil { @@ -177,21 +198,20 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 if err != nil { return nil, err } - return session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded) + s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded) + if err != nil { + return nil, err + } + return s, nil } -// NewInboundSession creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. -func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { - return a.NewInboundSessionFrom(nil, oneTimeKeyMsg) -} - -// NewInboundSessionFrom creates a new inbound session from an incoming PRE_KEY message. -func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { +// NewInboundSession creates a new inbound session from an incoming PRE_KEY message. +func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) { if len(oneTimeKeyMsg) == 0 { - return nil, fmt.Errorf("inbound session: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput) } var theirIdentityKeyDecoded *crypto.Curve25519PublicKey + var err error if theirIdentityKey != nil { theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey)) if err != nil { @@ -201,10 +221,14 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTime theirIdentityKeyDecoded = &theirIdentityKeyCurve } - return session.NewInboundOlmSession(theirIdentityKeyDecoded, []byte(oneTimeKeyMsg), a.searchOTKForOur, a.IdKeys.Curve25519) + s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519) + if err != nil { + return nil, err + } + return s, nil } -func (a *Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey { +func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey { for curIndex := range a.OTKeys { if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { return &a.OTKeys[curIndex] @@ -220,29 +244,27 @@ func (a *Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.One } // RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s. -func (a *Account) RemoveOneTimeKeys(s olm.Session) error { - toFind := s.(*session.OlmSession).BobOneTimeKey +func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) { + toFind := s.BobOneTimeKey for curIndex := range a.OTKeys { if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { //Remove and return a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1] a.OTKeys = a.OTKeys[:len(a.OTKeys)-1] - return nil + return } } - return nil //if the key is a fallback or prevFallback, don't remove it } -// GenFallbackKey generates a new fallback key. The old fallback key is stored -// in a.PrevFallbackKey overwriting any previous PrevFallbackKey. -func (a *Account) GenFallbackKey() error { +// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation. +func (a *Account) GenFallbackKey(reader io.Reader) error { a.PrevFallbackKey = a.CurrentFallbackKey key := crypto.OneTimeKey{ Published: false, ID: a.NextOneTimeKeyID, } - newKP, err := crypto.Curve25519GenerateKey() + newKP, err := crypto.Curve25519GenerateKey(reader) if err != nil { return err } @@ -257,10 +279,10 @@ func (a *Account) GenFallbackKey() error { // FallbackKey returns the public part of the current fallback key of the Account. // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a *Account) FallbackKey() map[string]id.Curve25519 { +func (a Account) FallbackKey() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 { - keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded() + keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) } return keys } @@ -275,7 +297,7 @@ func (a *Account) FallbackKey() map[string]id.Curve25519 { } } */ -func (a *Account) FallbackKeyJSON() ([]byte, error) { +func (a Account) FallbackKeyJSON() ([]byte, error) { res := make(map[string]map[string]id.Curve25519) fbk := a.FallbackKey() res["curve25519"] = fbk @@ -284,10 +306,10 @@ func (a *Account) FallbackKeyJSON() ([]byte, error) { // FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished. // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 { +func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published { - keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded() + keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) } return keys } @@ -302,7 +324,7 @@ func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 { } } */ -func (a *Account) FallbackKeyUnpublishedJSON() ([]byte, error) { +func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) { res := make(map[string]map[string]id.Curve25519) fbk := a.FallbackKeyUnpublished() res["curve25519"] = fbk @@ -320,50 +342,69 @@ func (a *Account) ForgetOldFallbackKey() { // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Account) Unpickle(pickled, key []byte) error { - decrypted, err := libolmpickle.Unpickle(key, pickled) + decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err } - return a.UnpickleLibOlm(decrypted) + _, err = a.UnpickleLibOlm(decrypted) + return err } -// UnpickleLibOlm unpickles the unencryted value and populates the [Account] accordingly. -func (a *Account) UnpickleLibOlm(buf []byte) error { - decoder := libolmpickle.NewDecoder(buf) - pickledVersion, err := decoder.ReadUInt32() +// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read. +func (a *Account) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) if err != nil { - return err - } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { - return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) - } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair - return err - } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair - return err + return 0, err } - - otkCount, err := decoder.ReadUInt32() + switch pickledVersion { + case accountPickleVersionLibOLM, 3, 2: + default: + return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion) + } + //read ed25519 key pair + readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:]) if err != nil { - return err + return 0, err } - - a.OTKeys = make([]crypto.OneTimeKey, otkCount) - for i := uint32(0); i < otkCount; i++ { - if err := a.OTKeys[i].UnpickleLibOlm(decoder); err != nil { - return err + curPos += readBytes + //read curve25519 key pair + readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + //Read number of onetimeKeys + numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + //Read i one time keys + a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys) + for i := uint32(0); i < numberOTKeys; i++ { + readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err } + curPos += readBytes } - if pickledVersion <= 2 { // version 2 did not have fallback keys a.NumFallbackKeys = 0 } else if pickledVersion == 3 { // version 3 used the published flag to indicate how many fallback keys // were present (we'll have to assume that the keys were published) - if err = a.CurrentFallbackKey.UnpickleLibOlm(decoder); err != nil { - return err - } else if err = a.PrevFallbackKey.UnpickleLibOlm(decoder); err != nil { - return err + readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err } + curPos += readBytes + readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes if a.CurrentFallbackKey.Published { if a.PrevFallbackKey.Published { a.NumFallbackKeys = 2 @@ -374,70 +415,109 @@ func (a *Account) UnpickleLibOlm(buf []byte) error { a.NumFallbackKeys = 0 } } else { - // Read number of fallback keys - a.NumFallbackKeys, err = decoder.ReadUInt8() + //Read number of fallback keys + numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:]) if err != nil { - return err + return 0, err } - for i := 0; i < int(a.NumFallbackKeys); i++ { - switch i { - case 0: - if err = a.CurrentFallbackKey.UnpickleLibOlm(decoder); err != nil { - return err - } - case 1: - if err = a.PrevFallbackKey.UnpickleLibOlm(decoder); err != nil { - return err - } - default: - // Just drain any remaining fallback keys - if err = (&crypto.OneTimeKey{}).UnpickleLibOlm(decoder); err != nil { - return err + curPos += readBytes + a.NumFallbackKeys = numFallbackKeys + if a.NumFallbackKeys >= 1 { + readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + if a.NumFallbackKeys >= 2 { + readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err } + curPos += readBytes } } } - - //Read next onetime key ID - a.NextOneTimeKeyID, err = decoder.ReadUInt32() - return err + //Read next onetime key id + nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + a.NextOneTimeKeyID = nextOTKeyID + return curPos, nil } // Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm(). -func (a *Account) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided +func (a Account) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, a.PickleLen()) + written, err := a.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err } - return libolmpickle.Pickle(key, a.PickleLibOlm()) + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil } -// PickleLibOlm pickles the [Account] and returns the raw bytes. -func (a *Account) PickleLibOlm() []byte { - encoder := libolmpickle.NewEncoder() - encoder.WriteUInt32(accountPickleVersionLibOLM) - a.IdKeys.Ed25519.PickleLibOlm(encoder) - a.IdKeys.Curve25519.PickleLibOlm(encoder) - - // One-Time Keys - encoder.WriteUInt32(uint32(len(a.OTKeys))) - for _, curOTKey := range a.OTKeys { - curOTKey.PickleLibOlm(encoder) +// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (a Account) PickleLibOlm(target []byte) (int, error) { + if len(target) < a.PickleLen() { + return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) } - - // Fallback Keys - encoder.WriteUInt8(a.NumFallbackKeys) + written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target) + writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle account: %w", err) + } + written += writtenEdKey + writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle account: %w", err) + } + written += writtenCurveKey + written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:]) + for _, curOTKey := range a.OTKeys { + writtenOT, err := curOTKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle account: %w", err) + } + written += writtenOT + } + written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:]) if a.NumFallbackKeys >= 1 { - a.CurrentFallbackKey.PickleLibOlm(encoder) + writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle account: %w", err) + } + written += writtenOT + if a.NumFallbackKeys >= 2 { - a.PrevFallbackKey.PickleLibOlm(encoder) + writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle account: %w", err) + } + written += writtenOT } } - encoder.WriteUInt32(a.NextOneTimeKeyID) - return encoder.Bytes() + written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:]) + return written, nil } -// MaxNumberOfOneTimeKeys returns the largest number of one time keys this -// Account can store. -func (a *Account) MaxNumberOfOneTimeKeys() uint { - return uint(MaxOneTimeKeys) +// PickleLen returns the number of bytes the pickled Account will have. +func (a Account) PickleLen() int { + length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM) + length += a.IdKeys.Ed25519.PickleLen() + length += a.IdKeys.Curve25519.PickleLen() + length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys))) + length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen()) + length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys) + length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen()) + length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID) + return length } diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index d0dec5f0..943d8570 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -1,56 +1,92 @@ package account_test import ( + "bytes" "encoding/base64" + "errors" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" ) func TestAccount(t *testing.T) { - firstAccount, err := account.NewAccount() - assert.NoError(t, err) - err = firstAccount.GenFallbackKey() - assert.NoError(t, err) - err = firstAccount.GenOneTimeKeys(2) - assert.NoError(t, err) + firstAccount, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = firstAccount.GenFallbackKey(nil) + if err != nil { + t.Fatal(err) + } + err = firstAccount.GenOneTimeKeys(nil, 2) + if err != nil { + t.Fatal(err) + } encryptionKey := []byte("testkey") - //now pickle account in JSON format pickled, err := firstAccount.PickleAsJSON(encryptionKey) - assert.NoError(t, err) - + if err != nil { + t.Fatal(err) + } //now unpickle into new Account unpickledAccount, err := account.AccountFromJSONPickled(pickled, encryptionKey) - assert.NoError(t, err) - + if err != nil { + t.Fatal(err) + } //check if accounts are the same - assert.Equal(t, firstAccount.NextOneTimeKeyID, unpickledAccount.NextOneTimeKeyID) - assert.Equal(t, firstAccount.CurrentFallbackKey, unpickledAccount.CurrentFallbackKey) - assert.Equal(t, firstAccount.PrevFallbackKey, unpickledAccount.PrevFallbackKey) - assert.Equal(t, firstAccount.OTKeys, unpickledAccount.OTKeys) - assert.Equal(t, firstAccount.IdKeys, unpickledAccount.IdKeys) + if firstAccount.NextOneTimeKeyID != unpickledAccount.NextOneTimeKeyID { + t.Fatal("NextOneTimeKeyID unequal") + } + if !firstAccount.CurrentFallbackKey.Equal(unpickledAccount.CurrentFallbackKey) { + t.Fatal("CurrentFallbackKey unequal") + } + if !firstAccount.PrevFallbackKey.Equal(unpickledAccount.PrevFallbackKey) { + t.Fatal("PrevFallbackKey unequal") + } + if len(firstAccount.OTKeys) != len(unpickledAccount.OTKeys) { + t.Fatal("OneTimeKeysunequal") + } + for i := range firstAccount.OTKeys { + if !firstAccount.OTKeys[i].Equal(unpickledAccount.OTKeys[i]) { + t.Fatalf("OneTimeKeys %d unequal", i) + } + } + if !firstAccount.IdKeys.Curve25519.PrivateKey.Equal(unpickledAccount.IdKeys.Curve25519.PrivateKey) { + t.Fatal("IdentityKeys Curve25519 private unequal") + } + if !firstAccount.IdKeys.Curve25519.PublicKey.Equal(unpickledAccount.IdKeys.Curve25519.PublicKey) { + t.Fatal("IdentityKeys Curve25519 public unequal") + } + if !firstAccount.IdKeys.Ed25519.PrivateKey.Equal(unpickledAccount.IdKeys.Ed25519.PrivateKey) { + t.Fatal("IdentityKeys Ed25519 private unequal") + } + if !firstAccount.IdKeys.Ed25519.PublicKey.Equal(unpickledAccount.IdKeys.Ed25519.PublicKey) { + t.Fatal("IdentityKeys Ed25519 public unequal") + } - // Ensure that all of the keys are unpublished right now - otks, err := firstAccount.OneTimeKeys() - assert.NoError(t, err) - assert.Len(t, otks, 2) - assert.Len(t, firstAccount.FallbackKeyUnpublished(), 1) - - // Now, publish the key and make sure that they are published + if len(firstAccount.OneTimeKeys()) != 2 { + t.Fatal("should get 2 unpublished oneTimeKeys") + } + if len(firstAccount.FallbackKeyUnpublished()) == 0 { + t.Fatal("should get fallbackKey") + } firstAccount.MarkKeysAsPublished() - - assert.Len(t, firstAccount.FallbackKeyUnpublished(), 0) - assert.Len(t, firstAccount.FallbackKey(), 1) - otks, err = firstAccount.OneTimeKeys() - assert.NoError(t, err) - assert.Len(t, otks, 0) + if len(firstAccount.FallbackKey()) == 0 { + t.Fatal("should get fallbackKey") + } + if len(firstAccount.FallbackKeyUnpublished()) != 0 { + t.Fatal("should get no fallbackKey") + } + if len(firstAccount.OneTimeKeys()) != 0 { + t.Fatal("should get no oneTimeKeys") + } } func TestAccountPickleJSON(t *testing.T) { @@ -68,49 +104,109 @@ func TestAccountPickleJSON(t *testing.T) { pickledData := []byte("6POkBWwbNl20fwvZWsOu0jgbHy4jkA5h0Ji+XCag59+ifWIRPDrqtgQi9HmkLiSF6wUhhYaV4S73WM+Hh+dlCuZRuXhTQr8yGPTifjcjq8birdAhObbEqHrYEdqaQkrgBLr/rlS5sibXeDqbkhVu4LslvootU9DkcCbd4b/0Flh7iugxqkcCs5GDndTEx9IzTVJzmK82Y0Q1Z1Z9Vuc2Iw746PtBJLtZjite6fSMp2NigPX/ZWWJ3OnwcJo0Vvjy8hgptZEWkamOHdWbUtelbHyjDIZlvxOC25D3rFif0zzPkF9qdpBPqVCWPPzGFmgnqKau6CHrnPfq7GLsM3BrprD7sHN1Js28ex14gXQPjBT7KTUo6H0e4gQMTMRp4qb8btNXDeId8xIFIElTh2SXZBTDmSq/ziVNJinEvYV8mGPvJZjDQQU+SyoS/HZ8uMc41tH0BOGDbFMHbfLMiz61E429gOrx2klu5lqyoyet7//HKi0ed5w2dQ") account, err := account.AccountFromJSONPickled(pickledData, key) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } expectedJSON := `{"ed25519":"qWvNB6Ztov5/AOsP073op0O32KJ8/tgSNarT7MaYgQE","curve25519":"TFUB6M6zwgyWhBEp2m1aUodl2AsnsrIuBr8l9AvwGS8"}` jsonData, err := account.IdentityKeysJSON() - assert.NoError(t, err) - assert.Equal(t, expectedJSON, string(jsonData)) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(jsonData, []byte(expectedJSON)) { + t.Fatalf("Expected '%s' but got '%s'", expectedJSON, jsonData) + } } func TestSessions(t *testing.T) { - aliceAccount, err := account.NewAccount() - assert.NoError(t, err) - err = aliceAccount.GenOneTimeKeys(5) - assert.NoError(t, err) - bobAccount, err := account.NewAccount() - assert.NoError(t, err) - err = bobAccount.GenOneTimeKeys(5) - assert.NoError(t, err) + aliceAccount, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = aliceAccount.GenOneTimeKeys(nil, 5) + if err != nil { + t.Fatal(err) + } + bobAccount, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = bobAccount.GenOneTimeKeys(nil, 5) + if err != nil { + t.Fatal(err) + } aliceSession, err := aliceAccount.NewOutboundSession(bobAccount.IdKeys.Curve25519.B64Encoded(), bobAccount.OTKeys[2].Key.B64Encoded()) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } plaintext := []byte("test message") - msgType, crypttext, err := aliceSession.Encrypt(plaintext) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType) + msgType, crypttext, err := aliceSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } - bobSession, err := bobAccount.NewInboundSession(string(crypttext)) - assert.NoError(t, err) - decodedText, err := bobSession.Decrypt(string(crypttext), msgType) - assert.NoError(t, err) - assert.Equal(t, plaintext, decodedText) + bobSession, err := bobAccount.NewInboundSession(nil, crypttext) + if err != nil { + t.Fatal(err) + } + decodedText, err := bobSession.Decrypt(crypttext, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decodedText) { + t.Fatalf("expected '%s' but got '%s'", string(plaintext), string(decodedText)) + } } func TestAccountPickle(t *testing.T) { pickleKey := []byte("secret_key") account, err := account.AccountFromPickled(pickledDataFromLibOlm, pickleKey) - assert.NoError(t, err) - assert.Equal(t, expectedEd25519KeyPairPickleLibOLM, account.IdKeys.Ed25519) - assert.Equal(t, expectedCurve25519KeyPairPickleLibOLM, account.IdKeys.Curve25519) - assert.EqualValues(t, 42, account.NextOneTimeKeyID) - assert.Equal(t, account.OTKeys, expectedOTKeysPickleLibOLM) - assert.EqualValues(t, 0, account.NumFallbackKeys) + if err != nil { + t.Fatal(err) + } + if !expectedEd25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Ed25519.PrivateKey) { + t.Fatal("keys not equal") + } + if !expectedEd25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Ed25519.PublicKey) { + t.Fatal("keys not equal") + } + if !expectedCurve25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Curve25519.PrivateKey) { + t.Fatal("keys not equal") + } + if !expectedCurve25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Curve25519.PublicKey) { + t.Fatal("keys not equal") + } + if account.NextOneTimeKeyID != 42 { + t.Fatal("wrong next otKey id") + } + if len(account.OTKeys) != len(expectedOTKeysPickleLibOLM) { + t.Fatal("wrong number of otKeys") + } + if account.NumFallbackKeys != 0 { + t.Fatal("fallback keys set but not in pickle") + } + for curIndex, curValue := range account.OTKeys { + curExpected := expectedOTKeysPickleLibOLM[curIndex] + if curExpected.ID != curValue.ID { + t.Fatal("OTKey id not correct") + } + if !curExpected.Key.PublicKey.Equal(curValue.Key.PublicKey) { + t.Fatal("OTKey public key not correct") + } + if !curExpected.Key.PrivateKey.Equal(curValue.Key.PrivateKey) { + t.Fatal("OTKey private key not correct") + } + } targetPickled, err := account.Pickle(pickleKey) - assert.NoError(t, err) - assert.Equal(t, pickledDataFromLibOlm, targetPickled) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(targetPickled, pickledDataFromLibOlm) { + t.Fatal("repickled value does not equal given value") + } } func TestOldAccountPickle(t *testing.T) { @@ -121,213 +217,356 @@ func TestOldAccountPickle(t *testing.T) { "K/A/8TOu9iK2hDFszy6xETiousHnHgh2ZGbRUh4pQx+YMm8ZdNZeRnwFGLnrWyf9" + "O5TmXua1FcU") pickleKey := []byte("") - account, err := account.NewAccount() - assert.NoError(t, err) + account, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } err = account.Unpickle(pickled, pickleKey) - assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion) + if err == nil { + t.Fatal("expected error") + } else { + if !errors.Is(err, goolm.ErrBadVersion) { + t.Fatal(err) + } + } } func TestLoopback(t *testing.T) { - accountA, err := account.NewAccount() - assert.NoError(t, err) + accountA, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } - accountB, err := account.NewAccount() - assert.NoError(t, err) - err = accountB.GenOneTimeKeys(42) - assert.NoError(t, err) + accountB, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = accountB.GenOneTimeKeys(nil, 42) + if err != nil { + t.Fatal(err) + } aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType) + msgType, message1, err := aliceSession.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } - bobSession, err := accountB.NewInboundSession(string(message1)) - assert.NoError(t, err) + bobSession, err := accountB.NewInboundSession(nil, message1) + if err != nil { + t.Fatal(err) + } // Check that the inbound session matches the message it was created from. - sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) - assert.NoError(t, err) - assert.True(t, sessionIsOK, "session was not detected to be valid") - + sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session was not detected to be valid") + } // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) - assert.NoError(t, err) - assert.True(t, sessionIsOK, "session is sad to be not from a but it should") - + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session is sad to be not from a but it should") + } // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) - assert.NoError(t, err) - assert.False(t, sessionIsOK, "session is sad to be from b but is from a") - + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session is sad to be from b but is from a") + } // Check that we can decrypt the message. - decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) - assert.NoError(t, err) - assert.Equal(t, plainText, decryptedMessage) + decryptedMessage, err := bobSession.Decrypt(message1, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } - msgTyp2, message2, err := bobSession.Encrypt(plainText) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypeMsg, msgTyp2) + msgTyp2, message2, err := bobSession.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + if msgTyp2 == id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } - decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2) - assert.NoError(t, err) - assert.Equal(t, plainText, decryptedMessage2) + decryptedMessage2, err := aliceSession.Decrypt(message2, msgTyp2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage2, plainText) { + t.Fatal("messages are not the same") + } //decrypting again should fail, as the chain moved on - _, err = aliceSession.Decrypt(string(message2), msgTyp2) - assert.Error(t, err) - assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound) + _, err = aliceSession.Decrypt(message2, msgTyp2) + if err == nil { + t.Fatal("expected error") + } //compare sessionIDs - assert.Equal(t, aliceSession.ID(), bobSession.ID()) + if aliceSession.ID() != bobSession.ID() { + t.Fatal("sessionIDs are not equal") + } } func TestMoreMessages(t *testing.T) { - accountA, err := account.NewAccount() - assert.NoError(t, err) + accountA, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } - accountB, err := account.NewAccount() - assert.NoError(t, err) - err = accountB.GenOneTimeKeys(42) - assert.NoError(t, err) + accountB, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = accountB.GenOneTimeKeys(nil, 42) + if err != nil { + t.Fatal(err) + } aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType) + msgType, message1, err := aliceSession.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } - bobSession, err := accountB.NewInboundSession(string(message1)) - assert.NoError(t, err) - decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) - assert.NoError(t, err) - assert.Equal(t, plainText, decryptedMessage) + bobSession, err := accountB.NewInboundSession(nil, message1) + if err != nil { + t.Fatal(err) + } + decryptedMessage, err := bobSession.Decrypt(message1, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } for i := 0; i < 8; i++ { //alice sends, bob reveices - msgType, message, err := aliceSession.Encrypt(plainText) - assert.NoError(t, err) + msgType, message, err := aliceSession.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } if i == 0 { //The first time should still be a preKeyMessage as bob has not yet send a message to alice - assert.Equal(t, id.OlmMsgTypePreKey, msgType) + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } } else { - assert.Equal(t, id.OlmMsgTypeMsg, msgType) + if msgType == id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + } + decryptedMessage, err := bobSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") } - decryptedMessage, err := bobSession.Decrypt(string(message), msgType) - assert.NoError(t, err) - assert.Equal(t, plainText, decryptedMessage) - //now bob sends, alice receives - msgType, message, err = bobSession.Encrypt(plainText) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypeMsg, msgType) - - decryptedMessage, err = aliceSession.Decrypt(string(message), msgType) - assert.NoError(t, err) - assert.Equal(t, plainText, decryptedMessage) + msgType, message, err = bobSession.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + if msgType == id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + decryptedMessage, err = aliceSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } } } func TestFallbackKey(t *testing.T) { - accountA, err := account.NewAccount() - assert.NoError(t, err) + accountA, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } - accountB, err := account.NewAccount() - assert.NoError(t, err) - err = accountB.GenFallbackKey() - assert.NoError(t, err) + accountB, err := account.NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = accountB.GenFallbackKey(nil) + if err != nil { + t.Fatal(err) + } fallBackKeys := accountB.FallbackKeyUnpublished() var fallbackKey id.Curve25519 for _, fbKey := range fallBackKeys { fallbackKey = fbKey } aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType) + msgType, message1, err := aliceSession.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } - bobSession, err := accountB.NewInboundSession(string(message1)) - assert.NoError(t, err) + bobSession, err := accountB.NewInboundSession(nil, message1) + if err != nil { + t.Fatal(err) + } // Check that the inbound session matches the message it was created from. - sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) - assert.NoError(t, err) - assert.True(t, sessionIsOK, "session was not detected to be valid") - + sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session was not detected to be valid") + } // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) - assert.NoError(t, err) - assert.True(t, sessionIsOK, "session is sad to be not from a but it should") - + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session is sad to be not from a but it should") + } // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) - assert.NoError(t, err) - assert.False(t, sessionIsOK, "session is sad to be from b but is from a") - + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session is sad to be from b but is from a") + } // Check that we can decrypt the message. - decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) - assert.NoError(t, err) - assert.Equal(t, plainText, decryptedMessage) + decryptedMessage, err := bobSession.Decrypt(message1, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } // create a new fallback key for B (the old fallback should still be usable) - err = accountB.GenFallbackKey() - assert.NoError(t, err) + err = accountB.GenFallbackKey(nil) + if err != nil { + t.Fatal(err) + } // start another session and encrypt a message aliceSession2, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - assert.NoError(t, err) - - msgType2, message2, err := aliceSession2.Encrypt(plainText) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType2) + if err != nil { + t.Fatal(err) + } + msgType2, message2, err := aliceSession2.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + if msgType2 != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } // bobSession should not be valid for the message2 // Check that the inbound session matches the message it was created from. - sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2)) - assert.NoError(t, err) - assert.False(t, sessionIsOK, "session was detected to be valid but should not") - - bobSession2, err := accountB.NewInboundSession(string(message2)) - assert.NoError(t, err) + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(nil, message2) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session was detected to be valid but should not") + } + bobSession2, err := accountB.NewInboundSession(nil, message2) + if err != nil { + t.Fatal(err) + } // Check that the inbound session matches the message it was created from. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2)) - assert.NoError(t, err) - assert.True(t, sessionIsOK, "session was not detected to be valid") - + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(nil, message2) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session was not detected to be valid") + } // Check that the inbound session matches the key this message is supposed to be from. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2)) - assert.NoError(t, err) - assert.True(t, sessionIsOK, "session is sad to be not from a but it should") - + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&aIDKey, message2) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session is sad to be not from a but it should") + } // Check that the inbound session isn't from a different user. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2)) - assert.NoError(t, err) - assert.False(t, sessionIsOK, "session is sad to be from b but is from a") - + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&bIDKey, message2) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session is sad to be from b but is from a") + } // Check that we can decrypt the message. - decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2) - assert.NoError(t, err) - assert.Equal(t, plainText, decryptedMessage2) + decryptedMessage2, err := bobSession2.Decrypt(message2, msgType2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage2, plainText) { + t.Fatal("messages are not the same") + } //Forget the old fallback key -- creating a new session should fail now accountB.ForgetOldFallbackKey() // start another session and encrypt a message aliceSession3, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - assert.NoError(t, err) - msgType3, message3, err := aliceSession3.Encrypt(plainText) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType3) - _, err = accountB.NewInboundSession(string(message3)) - assert.ErrorIs(t, err, olm.ErrBadMessageKeyID) + if err != nil { + t.Fatal(err) + } + msgType3, message3, err := aliceSession3.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + if msgType3 != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + _, err = accountB.NewInboundSession(nil, message3) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, goolm.ErrBadMessageKeyID) { + t.Fatal(err) + } } func TestOldV3AccountPickle(t *testing.T) { @@ -343,23 +582,33 @@ func TestOldV3AccountPickle(t *testing.T) { expectedUnpublishedFallbackJSON := []byte("{\"curve25519\":{}}") account, err := account.AccountFromPickled(pickledData, pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } fallbackJSON, err := account.FallbackKeyJSON() - assert.NoError(t, err) - assert.Equal(t, expectedFallbackJSON, fallbackJSON) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(fallbackJSON, expectedFallbackJSON) { + t.Fatalf("expected not as result:\n%s\n%s\n", expectedFallbackJSON, fallbackJSON) + } fallbackJSONUnpublished, err := account.FallbackKeyUnpublishedJSON() - assert.NoError(t, err) - assert.Equal(t, expectedUnpublishedFallbackJSON, fallbackJSONUnpublished) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(fallbackJSONUnpublished, expectedUnpublishedFallbackJSON) { + t.Fatalf("expected not as result:\n%s\n%s\n", expectedUnpublishedFallbackJSON, fallbackJSONUnpublished) + } } func TestAccountSign(t *testing.T) { - accountA, err := account.NewAccount() - assert.NoError(t, err) + accountA, err := account.NewAccount(nil) + require.NoError(t, err) plainText := []byte("Hello, World") signatureB64, err := accountA.Sign(plainText) - assert.NoError(t, err) + require.NoError(t, err) signature, err := base64.RawStdEncoding.DecodeString(string(signatureB64)) - assert.NoError(t, err) + require.NoError(t, err) verified, err := signatures.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature) assert.NoError(t, err) diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go deleted file mode 100644 index ec392d7e..00000000 --- a/crypto/goolm/account/register.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package account - -import ( - "maunium.net/go/mautrix/crypto/olm" -) - -func Register() { - olm.InitNewAccount = func() (olm.Account, error) { - return NewAccount() - } - olm.InitBlankAccount = func() olm.Account { - return &Account{} - } - olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { - return AccountFromPickled(pickled, key) - } -} diff --git a/crypto/goolm/aessha2/aessha2.go b/crypto/goolm/aessha2/aessha2.go deleted file mode 100644 index 42d9811b..00000000 --- a/crypto/goolm/aessha2/aessha2.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -// Package aessha2 implements the m.megolm.v1.aes-sha2 encryption algorithm -// described in [Section 10.12.4.3] in the Spec -// -// [Section 10.12.4.3]: https://spec.matrix.org/v1.12/client-server-api/#mmegolmv1aes-sha2 -package aessha2 - -import ( - "crypto/hmac" - "crypto/sha256" - "crypto/subtle" - "io" - - "golang.org/x/crypto/hkdf" - - "maunium.net/go/mautrix/crypto/aescbc" -) - -type AESSHA2 struct { - aesKey, hmacKey, iv []byte -} - -func NewAESSHA2(secret, info []byte) (AESSHA2, error) { - kdf := hkdf.New(sha256.New, secret, nil, info) - keymatter := make([]byte, 80) - _, err := io.ReadFull(kdf, keymatter) - return AESSHA2{ - keymatter[:32], // AES Key - keymatter[32:64], // HMAC Key - keymatter[64:], // IV - }, err -} - -func (a *AESSHA2) Encrypt(plaintext []byte) ([]byte, error) { - return aescbc.Encrypt(a.aesKey, a.iv, plaintext) -} - -func (a *AESSHA2) Decrypt(ciphertext []byte) ([]byte, error) { - return aescbc.Decrypt(a.aesKey, a.iv, ciphertext) -} - -func (a *AESSHA2) MAC(ciphertext []byte) ([]byte, error) { - hash := hmac.New(sha256.New, a.hmacKey) - _, err := hash.Write(ciphertext) - return hash.Sum(nil), err -} - -func (a *AESSHA2) VerifyMAC(ciphertext, theirMAC []byte) (bool, error) { - if mac, err := a.MAC(ciphertext); err != nil { - return false, err - } else { - return subtle.ConstantTimeCompare(mac[:len(theirMAC)], theirMAC) == 1, nil - } -} diff --git a/crypto/goolm/aessha2/aessha2_test.go b/crypto/goolm/aessha2/aessha2_test.go deleted file mode 100644 index b2cfe8aa..00000000 --- a/crypto/goolm/aessha2/aessha2_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package aessha2_test - -import ( - "crypto/aes" - "testing" - - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/aessha2" -) - -func TestCipherAESSha256(t *testing.T) { - key := []byte("test key") - cipher, err := aessha2.NewAESSHA2(key, []byte("testKDFinfo")) - assert.NoError(t, err) - message := []byte("this is a random message for testing the implementation") - //increase to next block size - for len(message)%aes.BlockSize != 0 { - message = append(message, []byte("-")...) - } - encrypted, err := cipher.Encrypt([]byte(message)) - assert.NoError(t, err) - mac, err := cipher.MAC(encrypted) - assert.NoError(t, err) - - verified, err := cipher.VerifyMAC(encrypted, mac[:8]) - assert.NoError(t, err) - assert.True(t, verified, "signature verification failed") - - resultPlainText, err := cipher.Decrypt(encrypted) - assert.NoError(t, err) - assert.Equal(t, message, resultPlainText) -} diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/base64.go similarity index 62% rename from crypto/goolm/goolmbase64/base64.go rename to crypto/goolm/base64.go index 58ee26f7..229008cf 100644 --- a/crypto/goolm/goolmbase64/base64.go +++ b/crypto/goolm/base64.go @@ -1,12 +1,11 @@ -package goolmbase64 +package goolm import ( "encoding/base64" ) -// These methods should only be used for raw byte operations, never with string conversion - -func Decode(input []byte) ([]byte, error) { +// Deprecated: base64.RawStdEncoding should be used directly +func Base64Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) if err != nil { @@ -15,7 +14,8 @@ func Decode(input []byte) ([]byte, error) { return decoded[:writtenBytes], nil } -func Encode(input []byte) []byte { +// Deprecated: base64.RawStdEncoding should be used directly +func Base64Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) return encoded diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go new file mode 100644 index 00000000..2d2d58d5 --- /dev/null +++ b/crypto/goolm/cipher/aes_sha256.go @@ -0,0 +1,98 @@ +package cipher + +import ( + "bytes" + "crypto/aes" + "io" + + "maunium.net/go/mautrix/crypto/aescbc" + "maunium.net/go/mautrix/crypto/goolm/crypto" +) + +// derivedAESKeys stores the derived keys for the AESSHA256 cipher +type derivedAESKeys struct { + key []byte + hmacKey []byte + iv []byte +} + +// deriveAESKeys derives three keys for the AESSHA256 cipher +func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) { + hkdf := crypto.HKDFSHA256(key, nil, kdfInfo) + keys := &derivedAESKeys{ + key: make([]byte, 32), + hmacKey: make([]byte, 32), + iv: make([]byte, 16), + } + if _, err := io.ReadFull(hkdf, keys.key); err != nil { + return nil, err + } + if _, err := io.ReadFull(hkdf, keys.hmacKey); err != nil { + return nil, err + } + if _, err := io.ReadFull(hkdf, keys.iv); err != nil { + return nil, err + } + return keys, nil +} + +// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256. +func AESSha512BlockSize() int { + return aes.BlockSize +} + +// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. +type AESSHA256 struct { + kdfInfo []byte +} + +// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo). +func NewAESSHA256(kdfInfo []byte) *AESSHA256 { + return &AESSHA256{ + kdfInfo: kdfInfo, + } +} + +// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). +func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) { + keys, err := deriveAESKeys(c.kdfInfo, key) + if err != nil { + return nil, err + } + ciphertext, err = aescbc.Encrypt(keys.key, keys.iv, plaintext) + if err != nil { + return nil, err + } + return ciphertext, nil +} + +// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). +func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) { + keys, err := deriveAESKeys(c.kdfInfo, key) + if err != nil { + return nil, err + } + plaintext, err = aescbc.Decrypt(keys.key, keys.iv, ciphertext) + if err != nil { + return nil, err + } + return plaintext, nil +} + +// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes). +func (c AESSHA256) MAC(key, message []byte) ([]byte, error) { + keys, err := deriveAESKeys(c.kdfInfo, key) + if err != nil { + return nil, err + } + return crypto.HMACSHA256(keys.hmacKey, message), nil +} + +// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes). +func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) { + mac, err := c.MAC(key, message) + if err != nil { + return false, err + } + return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil +} diff --git a/crypto/goolm/cipher/aes_sha256_test.go b/crypto/goolm/cipher/aes_sha256_test.go new file mode 100644 index 00000000..d2f49cb1 --- /dev/null +++ b/crypto/goolm/cipher/aes_sha256_test.go @@ -0,0 +1,83 @@ +package cipher + +import ( + "bytes" + "crypto/aes" + "testing" +) + +func TestDeriveAESKeys(t *testing.T) { + kdfInfo := []byte("test") + key := []byte("test key") + derivedKeys, err := deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + derivedKeys2, err := deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + //derivedKeys and derivedKeys2 should be identical + if !bytes.Equal(derivedKeys.key, derivedKeys2.key) || + !bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || + !bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { + t.Fail() + } + //changing kdfInfo + kdfInfo = []byte("other kdf") + derivedKeys2, err = deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + //derivedKeys and derivedKeys2 should now be different + if bytes.Equal(derivedKeys.key, derivedKeys2.key) || + bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || + bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { + t.Fail() + } + //changing key + key = []byte("other test key") + derivedKeys, err = deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + //derivedKeys and derivedKeys2 should now be different + if bytes.Equal(derivedKeys.key, derivedKeys2.key) || + bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || + bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { + t.Fail() + } +} + +func TestCipherAESSha256(t *testing.T) { + key := []byte("test key") + cipher := NewAESSHA256([]byte("testKDFinfo")) + message := []byte("this is a random message for testing the implementation") + //increase to next block size + for len(message)%aes.BlockSize != 0 { + message = append(message, []byte("-")...) + } + encrypted, err := cipher.Encrypt(key, []byte(message)) + if err != nil { + t.Fatal(err) + } + mac, err := cipher.MAC(key, encrypted) + if err != nil { + t.Fatal(err) + } + + verified, err := cipher.Verify(key, encrypted, mac[:8]) + if err != nil { + t.Fatal(err) + } + if !verified { + t.Fatal("signature verification failed") + } + resultPlainText, err := cipher.Decrypt(key, encrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(message, resultPlainText) { + t.Fail() + } +} diff --git a/crypto/goolm/cipher/cipher.go b/crypto/goolm/cipher/cipher.go new file mode 100644 index 00000000..43580b0b --- /dev/null +++ b/crypto/goolm/cipher/cipher.go @@ -0,0 +1,18 @@ +// Package cipher provides the methods and structs to do encryptions for +// olm/megolm. +package cipher + +// Cipher defines a valid cipher. +type Cipher interface { + // Encrypt encrypts the plaintext. + Encrypt(key, plaintext []byte) (ciphertext []byte, err error) + + // Decrypt decrypts the ciphertext. + Decrypt(key, ciphertext []byte) (plaintext []byte, err error) + + //MAC returns the MAC of the message calculated with the key. + MAC(key, message []byte) ([]byte, error) + + //Verify checks the MAC of the message calculated with the key against the givenMAC. + Verify(key, message, givenMAC []byte) (bool, error) +} diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go new file mode 100644 index 00000000..670ff6ff --- /dev/null +++ b/crypto/goolm/cipher/pickle.go @@ -0,0 +1,58 @@ +package cipher + +import ( + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" +) + +const ( + kdfPickle = "Pickle" //used to derive the keys for encryption + pickleMACLength = 8 +) + +// PickleBlockSize returns the blocksize of the used cipher. +func PickleBlockSize() int { + return AESSha512BlockSize() +} + +// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. +func Pickle(key, input []byte) ([]byte, error) { + pickleCipher := NewAESSHA256([]byte(kdfPickle)) + ciphertext, err := pickleCipher.Encrypt(key, input) + if err != nil { + return nil, err + } + mac, err := pickleCipher.MAC(key, ciphertext) + if err != nil { + return nil, err + } + ciphertext = append(ciphertext, mac[:pickleMACLength]...) + encoded := goolm.Base64Encode(ciphertext) + return encoded, nil +} + +// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. +func Unpickle(key, input []byte) ([]byte, error) { + pickleCipher := NewAESSHA256([]byte(kdfPickle)) + ciphertext, err := goolm.Base64Decode(input) + if err != nil { + return nil, err + } + //remove mac and check + verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:]) + if err != nil { + return nil, err + } + if !verified { + return nil, fmt.Errorf("decrypt pickle: %w", goolm.ErrBadMAC) + } + //Set to next block size + targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) + copy(targetCipherText, ciphertext) + plaintext, err := pickleCipher.Decrypt(key, targetCipherText) + if err != nil { + return nil, err + } + return plaintext, nil +} diff --git a/crypto/goolm/cipher/pickle_test.go b/crypto/goolm/cipher/pickle_test.go new file mode 100644 index 00000000..b47bf3ea --- /dev/null +++ b/crypto/goolm/cipher/pickle_test.go @@ -0,0 +1,33 @@ +package cipher_test + +import ( + "bytes" + "crypto/aes" + "testing" + + "maunium.net/go/mautrix/crypto/goolm/cipher" +) + +func TestEncoding(t *testing.T) { + key := []byte("test key") + input := []byte("test") + //pad marshaled to get block size + toEncrypt := input + if len(input)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(input)%aes.BlockSize + toEncrypt = make([]byte, len(input)+padding) + copy(toEncrypt, input) + } + encoded, err := cipher.Pickle(key, toEncrypt) + if err != nil { + t.Fatal(err) + } + + decoded, err := cipher.Unpickle(key, encoded) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decoded, toEncrypt) { + t.Fatalf("Expected '%s' but got '%s'", toEncrypt, decoded) + } +} diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 6e42d886..125e1bfd 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -1,49 +1,67 @@ package crypto import ( + "bytes" "crypto/rand" - "crypto/subtle" "encoding/base64" + "fmt" + "io" "golang.org/x/crypto/curve25519" + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/id" ) const ( - Curve25519PrivateKeyLength = curve25519.ScalarSize //The length of the private key. - Curve25519PublicKeyLength = 32 + Curve25519KeyLength = curve25519.ScalarSize //The length of the private key. + curve25519PubKeyLength = 32 ) -// Curve25519KeyPair stores both parts of a curve25519 key. -type Curve25519KeyPair struct { - PrivateKey Curve25519PrivateKey `json:"private,omitempty"` - PublicKey Curve25519PublicKey `json:"public,omitempty"` -} - -// Curve25519GenerateKey creates a new curve25519 key pair. -func Curve25519GenerateKey() (Curve25519KeyPair, error) { - privateKeyByte := make([]byte, Curve25519PrivateKeyLength) - if _, err := rand.Read(privateKeyByte); err != nil { - return Curve25519KeyPair{}, err +// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand. +func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) { + privateKeyByte := make([]byte, Curve25519KeyLength) + if reader == nil { + _, err := rand.Read(privateKeyByte) + if err != nil { + return Curve25519KeyPair{}, err + } + } else { + _, err := reader.Read(privateKeyByte) + if err != nil { + return Curve25519KeyPair{}, err + } } privateKey := Curve25519PrivateKey(privateKeyByte) + publicKey, err := privateKey.PubKey() + if err != nil { + return Curve25519KeyPair{}, err + } return Curve25519KeyPair{ PrivateKey: Curve25519PrivateKey(privateKey), PublicKey: Curve25519PublicKey(publicKey), - }, err + }, nil } // Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given. func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) { publicKey, err := private.PubKey() + if err != nil { + return Curve25519KeyPair{}, err + } return Curve25519KeyPair{ PrivateKey: private, PublicKey: Curve25519PublicKey(publicKey), - }, err + }, nil +} + +// Curve25519KeyPair stores both parts of a curve25519 key. +type Curve25519KeyPair struct { + PrivateKey Curve25519PrivateKey `json:"private,omitempty"` + PublicKey Curve25519PublicKey `json:"public,omitempty"` } // B64Encoded returns a base64 encoded string of the public key. @@ -53,30 +71,53 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { // SharedSecret returns the shared secret between the key pair and the given public key. func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { - // Note: the standard library checks that the output is non-zero return c.PrivateKey.SharedSecret(pubKey) } -// PickleLibOlm pickles the key pair into the encoder. -func (c Curve25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { - c.PublicKey.PickleLibOlm(encoder) - if len(c.PrivateKey) == Curve25519PrivateKeyLength { - encoder.Write(c.PrivateKey) - } else { - encoder.WriteEmptyBytes(Curve25519PrivateKeyLength) +// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, fmt.Errorf("pickle curve25519 key pair: %w", goolm.ErrValueTooShort) } + written, err := c.PublicKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle curve25519 key pair: %w", err) + } + if len(c.PrivateKey) != Curve25519KeyLength { + written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:]) + } else { + written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + } + return written, nil } // UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. -func (c *Curve25519KeyPair) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - if err := c.PublicKey.UnpickleLibOlm(decoder); err != nil { - return err - } else if privKey, err := decoder.ReadBytes(Curve25519PrivateKeyLength); err != nil { - return err - } else { - c.PrivateKey = privKey - return nil +func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { + //unpickle PubKey + read, err := c.PublicKey.UnpickleLibOlm(value) + if err != nil { + return 0, err } + //unpickle PrivateKey + privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519KeyLength) + if err != nil { + return read, err + } + c.PrivateKey = privKey + return read + readPriv, nil +} + +// PickleLen returns the number of bytes the pickled key pair will have. +func (c Curve25519KeyPair) PickleLen() int { + lenPublic := c.PublicKey.PickleLen() + var lenPrivate int + if len(c.PrivateKey) != Curve25519KeyLength { + lenPrivate = libolmpickle.PickleBytesLen(make([]byte, Curve25519KeyLength)) + } else { + lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) + } + return lenPublic + lenPrivate } // Curve25519PrivateKey represents the private key for curve25519 usage @@ -84,12 +125,16 @@ type Curve25519PrivateKey []byte // Equal compares the private key to the given private key. func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool { - return subtle.ConstantTimeCompare(c, x) == 1 + return bytes.Equal(c, x) } // PubKey returns the public key derived from the private key. func (c Curve25519PrivateKey) PubKey() (Curve25519PublicKey, error) { - return curve25519.X25519(c, curve25519.Basepoint) + publicKey, err := curve25519.X25519(c, curve25519.Basepoint) + if err != nil { + return nil, err + } + return publicKey, nil } // SharedSecret returns the shared secret between the private key and the given public key. @@ -102,7 +147,7 @@ type Curve25519PublicKey []byte // Equal compares the public key to the given public key. func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool { - return subtle.ConstantTimeCompare(c, x) == 1 + return bytes.Equal(c, x) } // B64Encoded returns a base64 encoded string of the public key. @@ -110,18 +155,32 @@ func (c Curve25519PublicKey) B64Encoded() id.Curve25519 { return id.Curve25519(base64.RawStdEncoding.EncodeToString(c)) } -// PickleLibOlm pickles the public key into the encoder. -func (c Curve25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { - if len(c) == Curve25519PublicKeyLength { - encoder.Write(c) - } else { - encoder.WriteEmptyBytes(Curve25519PublicKeyLength) +// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, fmt.Errorf("pickle curve25519 public key: %w", goolm.ErrValueTooShort) } + if len(c) != curve25519PubKeyLength { + return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil + } + return libolmpickle.PickleBytes(c, target), nil } // UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. -func (c *Curve25519PublicKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - pubkey, err := decoder.ReadBytes(Curve25519PublicKeyLength) - *c = pubkey - return err +func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { + unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, curve25519PubKeyLength) + if err != nil { + return 0, err + } + *c = unpickled + return readBytes, nil +} + +// PickleLen returns the number of bytes the pickled public key will have. +func (c Curve25519PublicKey) PickleLen() int { + if len(c) != curve25519PubKeyLength { + return libolmpickle.PickleBytesLen(make([]byte, curve25519PubKeyLength)) + } + return libolmpickle.PickleBytesLen(c) } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index 2550f15e..f7df5edc 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -1,32 +1,39 @@ package crypto_test import ( + "bytes" "testing" - "github.com/stretchr/testify/assert" - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) -const curve25519KeyPairPickleLength = crypto.Curve25519PublicKeyLength + // Public Key - crypto.Curve25519PrivateKeyLength // Private Key - func TestCurve25519(t *testing.T) { - firstKeypair, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) - secondKeypair, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) + firstKeypair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + secondKeypair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } sharedSecretFromFirst, err := firstKeypair.SharedSecret(secondKeypair.PublicKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } sharedSecretFromSecond, err := secondKeypair.SharedSecret(firstKeypair.PublicKey) - assert.NoError(t, err) - assert.Equal(t, sharedSecretFromFirst, sharedSecretFromSecond, "shared secret not equal") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sharedSecretFromFirst, sharedSecretFromSecond) { + t.Fatal("shared secret not equal") + } fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) - assert.NoError(t, err) - assert.Equal(t, fromPrivate, firstKeypair) - _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength)) - assert.Error(t, err) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(fromPrivate.PublicKey, firstKeypair.PublicKey) { + t.Fatal("public keys not equal") + } } func TestCurve25519Case1(t *testing.T) { @@ -69,57 +76,112 @@ func TestCurve25519Case1(t *testing.T) { PublicKey: bobPublic, } agreementFromAlice, err := aliceKeyPair.SharedSecret(bobKeyPair.PublicKey) - assert.NoError(t, err) - assert.Equal(t, expectedAgreement, agreementFromAlice, "expected agreement does not match agreement from Alice's view") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(agreementFromAlice, expectedAgreement) { + t.Fatal("expected agreement does not match agreement from Alice's view") + } agreementFromBob, err := bobKeyPair.SharedSecret(aliceKeyPair.PublicKey) - assert.NoError(t, err) - assert.Equal(t, expectedAgreement, agreementFromBob, "expected agreement does not match agreement from Bob's view") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(agreementFromBob, expectedAgreement) { + t.Fatal("expected agreement does not match agreement from Bob's view") + } } func TestCurve25519Pickle(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) - - encoder := libolmpickle.NewEncoder() - keyPair.PickleLibOlm(encoder) - assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) + keyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } unpickledKeyPair := crypto.Curve25519KeyPair{} - err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) - assert.NoError(t, err) - assert.Equal(t, keyPair, unpickledKeyPair) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } } func TestCurve25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) - + keyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } //Remove privateKey keyPair.PrivateKey = nil - - encoder := libolmpickle.NewEncoder() - keyPair.PickleLibOlm(encoder) - assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) - + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } unpickledKeyPair := crypto.Curve25519KeyPair{} - err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) - assert.NoError(t, err) - assert.Equal(t, keyPair, unpickledKeyPair) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } } func TestCurve25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) + keyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } //Remove public keyPair.PublicKey = nil - encoder := libolmpickle.NewEncoder() - keyPair.PickleLibOlm(encoder) - assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } unpickledKeyPair := crypto.Curve25519KeyPair{} - err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) - assert.NoError(t, err) - assert.Equal(t, keyPair, unpickledKeyPair) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } } diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index a3345ba9..f0c56297 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -1,24 +1,30 @@ package crypto import ( + "crypto/ed25519" "encoding/base64" + "fmt" + "io" - "maunium.net/go/mautrix/crypto/ed25519" + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/id" ) const ( - Ed25519SignatureSize = ed25519.SignatureSize //The length of a signature + ED25519SignatureSize = ed25519.SignatureSize //The length of a signature ) -// Ed25519GenerateKey creates a new ed25519 key pair. -func Ed25519GenerateKey() (Ed25519KeyPair, error) { - publicKey, privateKey, err := ed25519.GenerateKey(nil) +// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand. +func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) { + publicKey, privateKey, err := ed25519.GenerateKey(reader) + if err != nil { + return Ed25519KeyPair{}, err + } return Ed25519KeyPair{ PrivateKey: Ed25519PrivateKey(privateKey), PublicKey: Ed25519PublicKey(publicKey), - }, err + }, nil } // Ed25519GenerateFromPrivate creates a new ed25519 key pair with the private key given. @@ -50,7 +56,7 @@ func (c Ed25519KeyPair) B64Encoded() id.Ed25519 { } // Sign returns the signature for the message. -func (c Ed25519KeyPair) Sign(message []byte) ([]byte, error) { +func (c Ed25519KeyPair) Sign(message []byte) []byte { return c.PrivateKey.Sign(message) } @@ -59,26 +65,51 @@ func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool { return c.PublicKey.Verify(message, givenSignature) } -// PickleLibOlm pickles the key pair into the encoder. -func (c Ed25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { - c.PublicKey.PickleLibOlm(encoder) - if len(c.PrivateKey) == ed25519.PrivateKeySize { - encoder.Write(c.PrivateKey) - } else { - encoder.WriteEmptyBytes(ed25519.PrivateKeySize) +// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, fmt.Errorf("pickle ed25519 key pair: %w", goolm.ErrValueTooShort) } + written, err := c.PublicKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle ed25519 key pair: %w", err) + } + + if len(c.PrivateKey) != ed25519.PrivateKeySize { + written += libolmpickle.PickleBytes(make([]byte, ed25519.PrivateKeySize), target[written:]) + } else { + written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + } + return written, nil } -// UnpickleLibOlm unpickles the unencryted value and populates the key pair accordingly. -func (c *Ed25519KeyPair) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - if err := c.PublicKey.UnpickleLibOlm(decoder); err != nil { - return err - } else if privKey, err := decoder.ReadBytes(ed25519.PrivateKeySize); err != nil { - return err - } else { - c.PrivateKey = privKey - return nil +// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. +func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { + //unpickle PubKey + read, err := c.PublicKey.UnpickleLibOlm(value) + if err != nil { + return 0, err } + //unpickle PrivateKey + privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], ed25519.PrivateKeySize) + if err != nil { + return read, err + } + c.PrivateKey = privKey + return read + readPriv, nil +} + +// PickleLen returns the number of bytes the pickled key pair will have. +func (c Ed25519KeyPair) PickleLen() int { + lenPublic := c.PublicKey.PickleLen() + var lenPrivate int + if len(c.PrivateKey) != ed25519.PrivateKeySize { + lenPrivate = libolmpickle.PickleBytesLen(make([]byte, ed25519.PrivateKeySize)) + } else { + lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) + } + return lenPublic + lenPrivate } // Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper. @@ -92,12 +123,12 @@ func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool { // PubKey returns the public key derived from the private key. func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey { publicKey := ed25519.PrivateKey(c).Public() - return Ed25519PublicKey(publicKey.([]byte)) + return Ed25519PublicKey(publicKey.(ed25519.PublicKey)) } // Sign returns the signature for the message. -func (c Ed25519PrivateKey) Sign(message []byte) ([]byte, error) { - return ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{}) +func (c Ed25519PrivateKey) Sign(message []byte) []byte { + return ed25519.Sign(ed25519.PrivateKey(c), message) } // Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper. @@ -118,19 +149,32 @@ func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool { return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature) } -// PickleLibOlm pickles the public key into the encoder. -func (c Ed25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { - if len(c) == ed25519.PublicKeySize { - encoder.Write(c) - } else { - encoder.WriteEmptyBytes(ed25519.PublicKeySize) +// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, fmt.Errorf("pickle ed25519 public key: %w", goolm.ErrValueTooShort) } + if len(c) != ed25519.PublicKeySize { + return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil + } + return libolmpickle.PickleBytes(c, target), nil } -// UnpickleLibOlm unpickles the unencryted value and populates the public key -// accordingly. -func (c *Ed25519PublicKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - key, err := decoder.ReadBytes(ed25519.PublicKeySize) - *c = key - return err +// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. +func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { + unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, ed25519.PublicKeySize) + if err != nil { + return 0, err + } + *c = unpickled + return readBytes, nil +} + +// PickleLen returns the number of bytes the pickled public key will have. +func (c Ed25519PublicKey) PickleLen() int { + if len(c) != ed25519.PublicKeySize { + return libolmpickle.PickleBytesLen(make([]byte, ed25519.PublicKeySize)) + } + return libolmpickle.PickleBytesLen(c) } diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index 610b8f3e..391de912 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -1,89 +1,140 @@ package crypto_test import ( + "bytes" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) -const ed25519KeyPairPickleLength = ed25519.PublicKeySize + // PublicKey - ed25519.PrivateKeySize // Private Key - func TestEd25519(t *testing.T) { - keypair, err := crypto.Ed25519GenerateKey() - assert.NoError(t, err) + keypair, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } message := []byte("test message") - signature, err := keypair.Sign(message) - require.NoError(t, err) - assert.True(t, keypair.Verify(message, signature)) + signature := keypair.Sign(message) + if !keypair.Verify(message, signature) { + t.Fail() + } } func TestEd25519Case1(t *testing.T) { //64 bytes for ed25519 package - keyPair, err := crypto.Ed25519GenerateKey() - assert.NoError(t, err) + keyPair, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } message := []byte("Hello, World") keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey) - assert.Equal(t, keyPair, keyPair2, "not equal key pairs") - signature, err := keyPair.Sign(message) - require.NoError(t, err) + if !bytes.Equal(keyPair.PublicKey, keyPair2.PublicKey) { + t.Fatal("not equal key pairs") + } + signature := keyPair.Sign(message) verified := keyPair.Verify(message, signature) - assert.True(t, verified, "message did not verify although it should") - + if !verified { + t.Fatal("message did not verify although it should") + } //Now change the message and verify again message = append(message, []byte("a")...) verified = keyPair.Verify(message, signature) - assert.False(t, verified, "message did verify although it should not") + if verified { + t.Fatal("message did verify although it should not") + } } func TestEd25519Pickle(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey() - assert.NoError(t, err) - encoder := libolmpickle.NewEncoder() - keyPair.PickleLibOlm(encoder) - assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) + keyPair, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } unpickledKeyPair := crypto.Ed25519KeyPair{} - err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) - assert.NoError(t, err) - assert.Equal(t, keyPair, unpickledKeyPair) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } } func TestEd25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey() - assert.NoError(t, err) + keyPair, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } //Remove privateKey keyPair.PrivateKey = nil - encoder := libolmpickle.NewEncoder() - keyPair.PickleLibOlm(encoder) - assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) - + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } unpickledKeyPair := crypto.Ed25519KeyPair{} - err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) - assert.NoError(t, err) - assert.Equal(t, keyPair, unpickledKeyPair) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } } func TestEd25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey() - assert.NoError(t, err) + keyPair, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } //Remove public keyPair.PublicKey = nil - encoder := libolmpickle.NewEncoder() - keyPair.PickleLibOlm(encoder) - assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) - + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } unpickledKeyPair := crypto.Ed25519KeyPair{} - err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) - assert.NoError(t, err) - assert.Equal(t, keyPair, unpickledKeyPair) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } } diff --git a/crypto/goolm/crypto/hmac.go b/crypto/goolm/crypto/hmac.go new file mode 100644 index 00000000..8542f7cb --- /dev/null +++ b/crypto/goolm/crypto/hmac.go @@ -0,0 +1,29 @@ +package crypto + +import ( + "crypto/hmac" + "crypto/sha256" + "io" + + "golang.org/x/crypto/hkdf" +) + +// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key. +func HMACSHA256(key, input []byte) []byte { + hash := hmac.New(sha256.New, key) + hash.Write(input) + return hash.Sum(nil) +} + +// SHA256 return the SHA-256 of the value. +func SHA256(value []byte) []byte { + hash := sha256.New() + hash.Write(value) + return hash.Sum(nil) +} + +// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil. +// The reader can be used to read an arbitary length of bytes which are based on all parameters. +func HKDFSHA256(input, salt, info []byte) io.Reader { + return hkdf.New(sha256.New, input, salt, info) +} diff --git a/crypto/goolm/crypto/hmac_test.go b/crypto/goolm/crypto/hmac_test.go new file mode 100644 index 00000000..95c0bfd5 --- /dev/null +++ b/crypto/goolm/crypto/hmac_test.go @@ -0,0 +1,114 @@ +package crypto_test + +import ( + "bytes" + "encoding/base64" + "io" + "testing" + + "maunium.net/go/mautrix/crypto/goolm/crypto" +) + +func TestHMACSha256(t *testing.T) { + key := []byte("test key") + message := []byte("test message") + hash := crypto.HMACSHA256(key, message) + if !bytes.Equal(hash, crypto.HMACSHA256(key, message)) { + t.Fail() + } + str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc" + result, err := base64.RawStdEncoding.DecodeString(str) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, hash) { + t.Fail() + } +} + +func TestHKDFSha256(t *testing.T) { + message := []byte("test content") + hkdf := crypto.HKDFSHA256(message, nil, nil) + hkdf2 := crypto.HKDFSHA256(message, nil, nil) + result := make([]byte, 32) + if _, err := io.ReadFull(hkdf, result); err != nil { + t.Fatal(err) + } + result2 := make([]byte, 32) + if _, err := io.ReadFull(hkdf2, result2); err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, result2) { + t.Fail() + } +} + +func TestSha256Case1(t *testing.T) { + input := make([]byte, 0) + expected := []byte{ + 0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14, + 0x9A, 0xFB, 0xF4, 0xC8, 0x99, 0x6F, 0xB9, 0x24, + 0x27, 0xAE, 0x41, 0xE4, 0x64, 0x9B, 0x93, 0x4C, + 0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55, + } + result := crypto.SHA256(input) + if !bytes.Equal(expected, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) + } +} + +func TestHMACCase1(t *testing.T) { + input := make([]byte, 0) + expected := []byte{ + 0xb6, 0x13, 0x67, 0x9a, 0x08, 0x14, 0xd9, 0xec, + 0x77, 0x2f, 0x95, 0xd7, 0x78, 0xc3, 0x5f, 0xc5, + 0xff, 0x16, 0x97, 0xc4, 0x93, 0x71, 0x56, 0x53, + 0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad, + } + result := crypto.HMACSHA256(input, input) + if !bytes.Equal(expected, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) + } +} + +func TestHDKFCase1(t *testing.T) { + input := []byte{ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + } + salt := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, + } + info := []byte{ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, + } + expectedHMAC := []byte{ + 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, + 0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, 0xba, 0x63, + 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, + 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5, + } + result := crypto.HMACSHA256(salt, input) + if !bytes.Equal(expectedHMAC, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHMAC) + } + expectedHDKF := []byte{ + 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, + 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a, + 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, + 0x5d, 0xb0, 0x2d, 0x56, 0xec, 0xc4, 0xc5, 0xbf, + 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, + 0x58, 0x65, + } + resultReader := crypto.HKDFSHA256(input, salt, info) + result = make([]byte, len(expectedHDKF)) + if _, err := io.ReadFull(resultReader, result); err != nil { + t.Fatal(err) + } + if !bytes.Equal(expectedHDKF, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHDKF) + } +} diff --git a/crypto/goolm/crypto/one_time_key.go b/crypto/goolm/crypto/one_time_key.go index 888b1749..67465563 100644 --- a/crypto/goolm/crypto/one_time_key.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -3,8 +3,11 @@ package crypto import ( "encoding/base64" "encoding/binary" + "fmt" + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/id" ) // OneTimeKey stores the information about a one time key. @@ -15,32 +18,78 @@ type OneTimeKey struct { } // Equal compares the one time key to the given one. -func (otk OneTimeKey) Equal(other OneTimeKey) bool { - return otk.ID == other.ID && - otk.Published == other.Published && - otk.Key.PrivateKey.Equal(other.Key.PrivateKey) && - otk.Key.PublicKey.Equal(other.Key.PublicKey) -} - -// PickleLibOlm pickles the key pair into the encoder. -func (c OneTimeKey) PickleLibOlm(encoder *libolmpickle.Encoder) { - encoder.WriteUInt32(c.ID) - encoder.WriteBool(c.Published) - c.Key.PickleLibOlm(encoder) -} - -// UnpickleLibOlm unpickles the unencryted value and populates the [OneTimeKey] -// accordingly. -func (c *OneTimeKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { - if c.ID, err = decoder.ReadUInt32(); err != nil { - return - } else if c.Published, err = decoder.ReadBool(); err != nil { - return +func (otk OneTimeKey) Equal(s OneTimeKey) bool { + if otk.ID != s.ID { + return false } - return c.Key.UnpickleLibOlm(decoder) + if otk.Published != s.Published { + return false + } + if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) { + return false + } + if !otk.Key.PublicKey.Equal(s.Key.PublicKey) { + return false + } + return true } -// KeyIDEncoded returns the base64 encoded key ID. -func (c OneTimeKey) KeyIDEncoded() string { - return base64.RawStdEncoding.EncodeToString(binary.BigEndian.AppendUint32(nil, c.ID)) +// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, fmt.Errorf("pickle one time key: %w", goolm.ErrValueTooShort) + } + written := libolmpickle.PickleUInt32(uint32(c.ID), target) + written += libolmpickle.PickleBool(c.Published, target[written:]) + writtenKey, err := c.Key.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle one time key: %w", err) + } + written += writtenKey + return written, nil +} + +// UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read. +func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) { + totalReadBytes := 0 + id, readBytes, err := libolmpickle.UnpickleUInt32(value) + if err != nil { + return 0, err + } + totalReadBytes += readBytes + c.ID = id + published, readBytes, err := libolmpickle.UnpickleBool(value[totalReadBytes:]) + if err != nil { + return 0, err + } + totalReadBytes += readBytes + c.Published = published + readBytes, err = c.Key.UnpickleLibOlm(value[totalReadBytes:]) + if err != nil { + return 0, err + } + totalReadBytes += readBytes + return totalReadBytes, nil +} + +// PickleLen returns the number of bytes the pickled OneTimeKey will have. +func (c OneTimeKey) PickleLen() int { + length := 0 + length += libolmpickle.PickleUInt32Len(c.ID) + length += libolmpickle.PickleBoolLen(c.Published) + length += c.Key.PickleLen() + return length +} + +// KeyIDEncoded returns the base64 encoded id. +func (c OneTimeKey) KeyIDEncoded() string { + resSlice := make([]byte, 4) + binary.BigEndian.PutUint32(resSlice, c.ID) + return base64.RawStdEncoding.EncodeToString(resSlice) +} + +// PublicKeyEncoded returns the base64 encoded public key +func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 { + return c.Key.PublicKey.B64Encoded() } diff --git a/crypto/goolm/errors.go b/crypto/goolm/errors.go new file mode 100644 index 00000000..6539b0f1 --- /dev/null +++ b/crypto/goolm/errors.go @@ -0,0 +1,28 @@ +package goolm + +import ( + "errors" +) + +// Those are the most common used errors +var ( + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("bad mac") + ErrBadMessageFormat = errors.New("bad message format") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key") + ErrBadMessageKeyID = errors.New("bad message key id") + ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrBadVersion = errors.New("wrong version") + ErrWrongPickleVersion = errors.New("wrong pickle version") + ErrValueTooShort = errors.New("value too short") + ErrInputToSmall = errors.New("input too small (truncated?)") + ErrOverflow = errors.New("overflow") +) diff --git a/crypto/goolm/libolmpickle/encoder.go b/crypto/goolm/libolmpickle/encoder.go deleted file mode 100644 index 63e7b09b..00000000 --- a/crypto/goolm/libolmpickle/encoder.go +++ /dev/null @@ -1,40 +0,0 @@ -package libolmpickle - -import ( - "bytes" - "encoding/binary" - - "go.mau.fi/util/exerrors" -) - -const ( - PickleBoolLength = 1 - PickleUInt8Length = 1 - PickleUInt32Length = 4 -) - -type Encoder struct { - bytes.Buffer -} - -func NewEncoder() *Encoder { return &Encoder{} } - -func (p *Encoder) WriteUInt8(value uint8) { - exerrors.PanicIfNotNil(p.WriteByte(value)) -} - -func (p *Encoder) WriteBool(value bool) { - if value { - exerrors.PanicIfNotNil(p.WriteByte(0x01)) - } else { - exerrors.PanicIfNotNil(p.WriteByte(0x00)) - } -} - -func (p *Encoder) WriteEmptyBytes(count int) { - exerrors.Must(p.Write(make([]byte, count))) -} - -func (p *Encoder) WriteUInt32(value uint32) { - exerrors.PanicIfNotNil(binary.Write(&p.Buffer, binary.BigEndian, value)) -} diff --git a/crypto/goolm/libolmpickle/encoder_test.go b/crypto/goolm/libolmpickle/encoder_test.go deleted file mode 100644 index c7811225..00000000 --- a/crypto/goolm/libolmpickle/encoder_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package libolmpickle_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" -) - -func TestEncoder(t *testing.T) { - var encoder libolmpickle.Encoder - encoder.WriteUInt32(4) - encoder.WriteUInt8(8) - encoder.WriteBool(false) - encoder.WriteEmptyBytes(10) - encoder.WriteBool(true) - encoder.Write([]byte("test")) - encoder.WriteUInt32(420_000) - assert.Equal(t, []byte{ - 0x00, 0x00, 0x00, 0x04, // 4 - 0x08, // 8 - 0x00, // false - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes - 0x01, //true - 0x74, 0x65, 0x73, 0x74, // "test" (ASCII) - 0x00, 0x06, 0x68, 0xa0, // 420,000 - }, encoder.Bytes()) -} - -func TestPickleUInt32(t *testing.T) { - values := []uint32{ - 0xffffffff, - 0x00ff00ff, - 0xf0000000, - 0xf00f0000, - } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - {0xf0, 0x0f, 0x00, 0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteUInt32(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleBool(t *testing.T) { - values := []bool{ - true, - false, - } - expected := [][]byte{ - {0x01}, - {0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteBool(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleUInt8(t *testing.T) { - values := []uint8{ - 0xff, - 0x1a, - } - expected := [][]byte{ - {0xff}, - {0x1a}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteUInt8(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleBytes(t *testing.T) { - values := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.Write(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} diff --git a/crypto/goolm/libolmpickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go index d15358fd..ec125a34 100644 --- a/crypto/goolm/libolmpickle/pickle.go +++ b/crypto/goolm/libolmpickle/pickle.go @@ -1,48 +1,41 @@ package libolmpickle import ( - "crypto/aes" - "fmt" - - "maunium.net/go/mautrix/crypto/goolm/aessha2" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" - "maunium.net/go/mautrix/crypto/olm" + "encoding/binary" ) -const pickleMACLength = 8 - -var kdfPickle = []byte("Pickle") //used to derive the keys for encryption - -// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. -func Pickle(key, plaintext []byte) ([]byte, error) { - if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { - return nil, err - } else if ciphertext, err := c.Encrypt(plaintext); err != nil { - return nil, err - } else if mac, err := c.MAC(ciphertext); err != nil { - return nil, err - } else { - return goolmbase64.Encode(append(ciphertext, mac[:pickleMACLength]...)), nil - } +func PickleUInt8(value uint8, target []byte) int { + target[0] = value + return 1 +} +func PickleUInt8Len(value uint8) int { + return 1 } -// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. -func Unpickle(key, input []byte) ([]byte, error) { - ciphertext, err := goolmbase64.Decode(input) - if err != nil { - return nil, err - } - ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:] - if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { - return nil, err - } else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil { - return nil, err - } else if !verified { - return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) +func PickleBool(value bool, target []byte) int { + if value { + target[0] = 0x01 } else { - // Set to next block size - targetCipherText := make([]byte, int(len(ciphertext)/aes.BlockSize)*aes.BlockSize) - copy(targetCipherText, ciphertext) - return c.Decrypt(targetCipherText) + target[0] = 0x00 } + return 1 +} +func PickleBoolLen(value bool) int { + return 1 +} + +func PickleBytes(value, target []byte) int { + return copy(target, value) +} +func PickleBytesLen(value []byte) int { + return len(value) +} + +func PickleUInt32(value uint32, target []byte) int { + res := make([]byte, 4) //4 bytes for int32 + binary.BigEndian.PutUint32(res, value) + return copy(target, res) +} +func PickleUInt32Len(value uint32) int { + return 4 } diff --git a/crypto/goolm/libolmpickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go index 0720e008..ce118428 100644 --- a/crypto/goolm/libolmpickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -1,26 +1,98 @@ -package libolmpickle +package libolmpickle_test import ( - "crypto/aes" + "bytes" "testing" - "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) -func TestEncoding(t *testing.T) { - key := []byte("test key") - input := []byte("test") - //pad marshaled to get block size - toEncrypt := input - if len(input)%aes.BlockSize != 0 { - padding := aes.BlockSize - len(input)%aes.BlockSize - toEncrypt = make([]byte, len(input)+padding) - copy(toEncrypt, input) +func TestPickleUInt32(t *testing.T) { + values := []uint32{ + 0xffffffff, + 0x00ff00ff, + 0xf0000000, + 0xf00f0000, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + {0xf0, 0x0f, 0x00, 0x00}, + } + for curIndex := range values { + response := make([]byte, 4) + resPLen := libolmpickle.PickleUInt32(values[curIndex], response) + if resPLen != libolmpickle.PickleUInt32Len(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestPickleBool(t *testing.T) { + values := []bool{ + true, + false, + } + expected := [][]byte{ + {0x01}, + {0x00}, + } + for curIndex := range values { + response := make([]byte, 1) + resPLen := libolmpickle.PickleBool(values[curIndex], response) + if resPLen != libolmpickle.PickleBoolLen(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestPickleUInt8(t *testing.T) { + values := []uint8{ + 0xff, + 0x1a, + } + expected := [][]byte{ + {0xff}, + {0x1a}, + } + for curIndex := range values { + response := make([]byte, 1) + resPLen := libolmpickle.PickleUInt8(values[curIndex], response) + if resPLen != libolmpickle.PickleUInt8Len(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestPickleBytes(t *testing.T) { + values := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + for curIndex := range values { + response := make([]byte, len(values[curIndex])) + resPLen := libolmpickle.PickleBytes(values[curIndex], response) + if resPLen != libolmpickle.PickleBytesLen(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } } - encoded, err := Pickle(key, toEncrypt) - assert.NoError(t, err) - - decoded, err := Unpickle(key, encoded) - assert.NoError(t, err) - assert.Equal(t, toEncrypt, decoded) } diff --git a/crypto/goolm/libolmpickle/unpickle.go b/crypto/goolm/libolmpickle/unpickle.go index d13be315..9a6a4b62 100644 --- a/crypto/goolm/libolmpickle/unpickle.go +++ b/crypto/goolm/libolmpickle/unpickle.go @@ -1,52 +1,53 @@ package libolmpickle import ( - "bytes" - "encoding/binary" "fmt" + + "maunium.net/go/mautrix/crypto/goolm" ) -func isZeroByteSlice(data []byte) bool { - for _, b := range data { - if b != 0 { - return false - } +func isZeroByteSlice(bytes []byte) bool { + b := byte(0) + for _, s := range bytes { + b |= s } - return true + return b == 0 } -type Decoder struct { - buf bytes.Buffer -} - -func NewDecoder(buf []byte) *Decoder { - return &Decoder{buf: *bytes.NewBuffer(buf)} -} - -func (d *Decoder) ReadUInt8() (uint8, error) { - return d.buf.ReadByte() -} - -func (d *Decoder) ReadBool() (bool, error) { - val, err := d.buf.ReadByte() - return val != 0x00, err -} - -func (d *Decoder) ReadBytes(length int) (data []byte, err error) { - data = d.buf.Next(length) - if len(data) != length { - return nil, fmt.Errorf("only %d in buffer, expected %d", len(data), length) - } else if isZeroByteSlice(data) { - return nil, nil +func UnpickleUInt8(value []byte) (uint8, int, error) { + if len(value) < 1 { + return 0, 0, fmt.Errorf("unpickle uint8: %w", goolm.ErrValueTooShort) } - return + return value[0], 1, nil } -func (d *Decoder) ReadUInt32() (uint32, error) { - data := d.buf.Next(4) - if len(data) != 4 { - return 0, fmt.Errorf("only %d bytes is buffer, expected 4 for uint32", len(data)) - } else { - return binary.BigEndian.Uint32(data), nil +func UnpickleBool(value []byte) (bool, int, error) { + if len(value) < 1 { + return false, 0, fmt.Errorf("unpickle bool: %w", goolm.ErrValueTooShort) } + return value[0] != uint8(0x00), 1, nil +} + +func UnpickleBytes(value []byte, length int) ([]byte, int, error) { + if len(value) < length { + return nil, 0, fmt.Errorf("unpickle bytes: %w", goolm.ErrValueTooShort) + } + resp := value[:length] + if isZeroByteSlice(resp) { + return nil, length, nil + } + return resp, length, nil +} + +func UnpickleUInt32(value []byte) (uint32, int, error) { + if len(value) < 4 { + return 0, 0, fmt.Errorf("unpickle uint32: %w", goolm.ErrValueTooShort) + } + var res uint32 + count := 0 + for i := 3; i >= 0; i-- { + res |= uint32(value[count]) << (8 * i) + count++ + } + return res, 4, nil } diff --git a/crypto/goolm/libolmpickle/unpickle_test.go b/crypto/goolm/libolmpickle/unpickle_test.go index 30355a76..937630e5 100644 --- a/crypto/goolm/libolmpickle/unpickle_test.go +++ b/crypto/goolm/libolmpickle/unpickle_test.go @@ -1,10 +1,9 @@ package libolmpickle_test import ( + "bytes" "testing" - "github.com/stretchr/testify/assert" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) @@ -20,10 +19,16 @@ func TestUnpickleUInt32(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - decoder := libolmpickle.NewDecoder(values[curIndex]) - response, err := decoder.ReadUInt32() - assert.NoError(t, err) - assert.Equal(t, expected[curIndex], response) + response, readLength, err := libolmpickle.UnpickleUInt32(values[curIndex]) + if err != nil { + t.Fatal(err) + } + if readLength != 4 { + t.Fatal("read bytes not correct") + } + if response != expected[curIndex] { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } } } @@ -39,10 +44,16 @@ func TestUnpickleBool(t *testing.T) { {0x02}, } for curIndex := range values { - decoder := libolmpickle.NewDecoder(values[curIndex]) - response, err := decoder.ReadBool() - assert.NoError(t, err) - assert.Equal(t, expected[curIndex], response) + response, readLength, err := libolmpickle.UnpickleBool(values[curIndex]) + if err != nil { + t.Fatal(err) + } + if readLength != 1 { + t.Fatal("read bytes not correct") + } + if response != expected[curIndex] { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } } } @@ -56,10 +67,16 @@ func TestUnpickleUInt8(t *testing.T) { {0x1a}, } for curIndex := range values { - decoder := libolmpickle.NewDecoder(values[curIndex]) - response, err := decoder.ReadUInt8() - assert.NoError(t, err) - assert.Equal(t, expected[curIndex], response) + response, readLength, err := libolmpickle.UnpickleUInt8(values[curIndex]) + if err != nil { + t.Fatal(err) + } + if readLength != 1 { + t.Fatal("read bytes not correct") + } + if response != expected[curIndex] { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } } } @@ -75,9 +92,15 @@ func TestUnpickleBytes(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - decoder := libolmpickle.NewDecoder(values[curIndex]) - response, err := decoder.ReadBytes(4) - assert.NoError(t, err) - assert.Equal(t, expected[curIndex], response) + response, readLength, err := libolmpickle.UnpickleBytes(values[curIndex], 4) + if err != nil { + t.Fatal(err) + } + if readLength != 4 { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } } } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index 3b5f1e4a..c3493f7b 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -2,17 +2,15 @@ package megolm import ( - "crypto/hmac" "crypto/rand" - "crypto/sha256" "fmt" - "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -25,7 +23,7 @@ const ( RatchetPartLength = 256 / 8 // length of each ratchet part in bytes ) -var megolmKeysKDFInfo = []byte("MEGOLM_KEYS") +var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS")) // hasKeySeed are the seed for the different ratchet parts var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{ @@ -64,9 +62,8 @@ func NewWithRandom(counter uint32) (*Ratchet, error) { // rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to. func (m *Ratchet) rehashPart(from, to int) { - hash := hmac.New(sha256.New, m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength]) - hash.Write(hashKeySeeds[to]) - copy(m.Data[to*RatchetPartLength:], hash.Sum(nil)) + newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to]) + copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength]) } // Advance advances the ratchet one step. @@ -135,8 +132,9 @@ func (m *Ratchet) AdvanceTo(target uint32) { // Encrypt encrypts the message in a message.GroupMessage with MAC and signature. // The output is base64 encoded. -func (r *Ratchet) Encrypt(plaintext []byte, key crypto.Ed25519KeyPair) ([]byte, error) { - cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) +func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) { + var err error + encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -144,12 +142,9 @@ func (r *Ratchet) Encrypt(plaintext []byte, key crypto.Ed25519KeyPair) ([]byte, message := &message.GroupMessage{} message.Version = protocolVersion message.MessageIndex = r.Counter - message.Ciphertext, err = cipher.Encrypt(plaintext) - if err != nil { - return nil, err - } - //creating the MAC and signing is done in encode - output, err := message.EncodeAndMACAndSign(cipher, key) + message.Ciphertext = encryptedText + //creating the mac and signing is done in encode + output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key) if err != nil { return nil, err } @@ -162,8 +157,8 @@ func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error m := message.MegolmSessionSharing{} m.Counter = r.Counter m.RatchetData = r.Data - encoded, err := m.EncodeAndSign(key) - return goolmbase64.Encode(encoded), err + encoded := m.EncodeAndSign(key) + return goolm.Base64Encode(encoded), nil } // SessionExportMessage creates a message in the session export format. @@ -173,51 +168,67 @@ func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, erro m.RatchetData = r.Data m.PublicKey = key encoded := m.Encode() - return goolmbase64.Encode(encoded), nil + return goolm.Base64Encode(encoded), nil } // Decrypt decrypts the ciphertext and verifies the MAC but not the signature. func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) { //verify mac - cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) - if err != nil { - return nil, err - } - verifiedMAC, err := msg.VerifyMACInline(cipher, ciphertext) + verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext) if err != nil { return nil, err } if !verifiedMAC { - return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) + return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) } - return cipher.Decrypt(msg.Ciphertext) + return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext) } // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return libolmpickle.PickleAsJSON(r, megolmPickleVersion, key) + return utilities.PickleAsJSON(r, megolmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return libolmpickle.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) + return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) } // UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. -func (r *Ratchet) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - ratchetData, err := decoder.ReadBytes(RatchetParts * RatchetPartLength) +func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { + //read ratchet data + curPos := 0 + ratchetData, readBytes, err := libolmpickle.UnpickleBytes(unpickled, RatchetParts*RatchetPartLength) if err != nil { - return err + return 0, err } copy(r.Data[:], ratchetData) - - r.Counter, err = decoder.ReadUInt32() - return err + curPos += readBytes + //Read counter + counter, readBytes, err := libolmpickle.UnpickleUInt32(unpickled[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + r.Counter = counter + return curPos, nil } -// PickleLibOlm pickles the ratchet into the encoder. -func (r Ratchet) PickleLibOlm(encoder *libolmpickle.Encoder) { - encoder.Write(r.Data[:]) - encoder.WriteUInt32(r.Counter) +// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r Ratchet) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) + } + written := libolmpickle.PickleBytes(r.Data[:], target) + written += libolmpickle.PickleUInt32(r.Counter, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled ratchet will have. +func (r Ratchet) PickleLen() int { + length := libolmpickle.PickleBytesLen(r.Data[:]) + length += libolmpickle.PickleUInt32Len(r.Counter) + return length } diff --git a/crypto/goolm/megolm/megolm_test.go b/crypto/goolm/megolm/megolm_test.go index a6f7c1a7..40289eaf 100644 --- a/crypto/goolm/megolm/megolm_test.go +++ b/crypto/goolm/megolm/megolm_test.go @@ -1,10 +1,9 @@ package megolm_test import ( + "bytes" "testing" - "github.com/stretchr/testify/assert" - "maunium.net/go/mautrix/crypto/goolm/megolm" ) @@ -20,7 +19,9 @@ func init() { func TestAdvance(t *testing.T) { m, err := megolm.New(0, startData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } expectedData := [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, @@ -33,7 +34,9 @@ func TestAdvance(t *testing.T) { 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, } m.Advance() - assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") + if !bytes.Equal(m.Data[:], expectedData[:]) { + t.Fatal("result after advancing the ratchet is not as expected") + } //repeat with complex advance m.Data = startData @@ -48,8 +51,9 @@ func TestAdvance(t *testing.T) { 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, } m.AdvanceTo(0x1000000) - assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") - + if !bytes.Equal(m.Data[:], expectedData[:]) { + t.Fatal("result after advancing the ratchet is not as expected") + } expectedData = [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, @@ -61,45 +65,77 @@ func TestAdvance(t *testing.T) { 0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a, } m.AdvanceTo(0x1041506) - assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") + if !bytes.Equal(m.Data[:], expectedData[:]) { + t.Fatal("result after advancing the ratchet is not as expected") + } } func TestAdvanceWraparound(t *testing.T) { m, err := megolm.New(0xffffffff, startData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } m.AdvanceTo(0x1000000) - assert.EqualValues(t, 0x1000000, m.Counter, "counter not correct") + if m.Counter != 0x1000000 { + t.Fatal("counter not correct") + } m2, err := megolm.New(0, startData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } m2.AdvanceTo(0x2000000) - assert.EqualValues(t, 0x2000000, m2.Counter, "counter not correct") - assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") + if m2.Counter != 0x2000000 { + t.Fatal("counter not correct") + } + if !bytes.Equal(m.Data[:], m2.Data[:]) { + t.Fatal("result after wrapping the ratchet is not as expected") + } } func TestAdvanceOverflowByOne(t *testing.T) { m, err := megolm.New(0xffffffff, startData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } m.AdvanceTo(0x0) - assert.EqualValues(t, 0x0, m.Counter, "counter not correct") + if m.Counter != 0x0 { + t.Fatal("counter not correct") + } m2, err := megolm.New(0xffffffff, startData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } m2.Advance() - assert.EqualValues(t, 0x0, m2.Counter, "counter not correct") - assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") + if m2.Counter != 0x0 { + t.Fatal("counter not correct") + } + if !bytes.Equal(m.Data[:], m2.Data[:]) { + t.Fatal("result after wrapping the ratchet is not as expected") + } } func TestAdvanceOverflow(t *testing.T) { m, err := megolm.New(0x1, startData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } m.AdvanceTo(0x80000000) m.AdvanceTo(0x0) - assert.EqualValues(t, 0x0, m.Counter, "counter not correct") + if m.Counter != 0x0 { + t.Fatal("counter not correct") + } m2, err := megolm.New(0x1, startData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } m2.AdvanceTo(0x0) - assert.EqualValues(t, 0x0, m2.Counter, "counter not correct") - assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") + if m2.Counter != 0x0 { + t.Fatal("counter not correct") + } + if !bytes.Equal(m.Data[:], m2.Data[:]) { + t.Fatal("result after wrapping the ratchet is not as expected") + } } diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index b06756a9..ba49f011 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -1,33 +1,70 @@ package message import ( - "bytes" "encoding/binary" - "fmt" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm" ) -type Decoder struct { - *bytes.Buffer -} - -func NewDecoder(buf []byte) *Decoder { - return &Decoder{bytes.NewBuffer(buf)} -} - -func (d *Decoder) ReadVarInt() (uint64, error) { - return binary.ReadUvarint(d) -} - -func (d *Decoder) ReadVarBytes() ([]byte, error) { - if n, err := d.ReadVarInt(); err != nil { - return nil, err - } else if n > uint64(d.Len()) { - return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available()) - } else { - out := make([]byte, n) - _, err = d.Read(out) - return out, err +// checkDecodeErr checks if there was an error during decode. +func checkDecodeErr(readBytes int) error { + if readBytes == 0 { + //end reached + return goolm.ErrInputToSmall } + if readBytes < 0 { + return goolm.ErrOverflow + } + return nil +} + +// decodeVarInt decodes a single big-endian encoded varint. +func decodeVarInt(input []byte) (uint32, int) { + value, readBytes := binary.Uvarint(input) + return uint32(value), readBytes +} + +// decodeVarString decodes the length of the string (varint) and returns the actual string +func decodeVarString(input []byte) ([]byte, int) { + stringLen, readBytes := decodeVarInt(input) + if readBytes <= 0 { + return nil, readBytes + } + input = input[readBytes:] + value := input[:stringLen] + readBytes += int(stringLen) + return value, readBytes +} + +// encodeVarIntByteLength returns the number of bytes needed to encode the uint32. +func encodeVarIntByteLength(input uint32) int { + result := 1 + for input >= 128 { + result++ + input >>= 7 + } + return result +} + +// encodeVarStringByteLength returns the number of bytes needed to encode the input. +func encodeVarStringByteLength(input []byte) int { + result := encodeVarIntByteLength(uint32(len(input))) + result += len(input) + return result +} + +// encodeVarInt encodes a single uint32 +func encodeVarInt(input uint32) []byte { + out := make([]byte, encodeVarIntByteLength(input)) + binary.PutUvarint(out, uint64(input)) + return out +} + +// encodeVarString encodes the length of the input (varint) and appends the actual input +func encodeVarString(input []byte) []byte { + out := make([]byte, encodeVarStringByteLength(input)) + length := encodeVarInt(uint32(len(input))) + copy(out, length) + copy(out[len(length):], input) + return out } diff --git a/crypto/goolm/message/encoder_test.go b/crypto/goolm/message/decoder_test.go similarity index 50% rename from crypto/goolm/message/encoder_test.go rename to crypto/goolm/message/decoder_test.go index 1fe2ebdb..39503e3e 100644 --- a/crypto/goolm/message/encoder_test.go +++ b/crypto/goolm/message/decoder_test.go @@ -1,13 +1,36 @@ -package message_test +package message import ( + "bytes" "testing" - - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/message" ) +func TestEncodeLengthInt(t *testing.T) { + numbers := []uint32{127, 128, 16383, 16384, 32767} + expected := []int{1, 2, 2, 3, 3} + for curIndex := range numbers { + if result := encodeVarIntByteLength(numbers[curIndex]); result != expected[curIndex] { + t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) + } + } +} + +func TestEncodeLengthString(t *testing.T) { + var strings [][]byte + var expected []int + strings = append(strings, []byte("test")) + expected = append(expected, 1+4) + strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------")) + expected = append(expected, 1+127) + strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------")) + expected = append(expected, 2+155) + for curIndex := range strings { + if result := encodeVarStringByteLength(strings[curIndex]); result != expected[curIndex] { + t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) + } + } +} + func TestEncodeInt(t *testing.T) { var ints []uint32 var expected [][]byte @@ -20,9 +43,9 @@ func TestEncodeInt(t *testing.T) { ints = append(ints, 16383) expected = append(expected, []byte{0b11111111, 0b01111111}) for curIndex := range ints { - var encoder message.Encoder - encoder.PutVarInt(uint64(ints[curIndex])) - assert.Equal(t, expected[curIndex], encoder.Bytes()) + if result := encodeVarInt(ints[curIndex]); !bytes.Equal(result, expected[curIndex]) { + t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) + } } } @@ -52,8 +75,8 @@ func TestEncodeString(t *testing.T) { res = append(res, curTest...) //Add string itself expected = append(expected, res) for curIndex := range strings { - var encoder message.Encoder - encoder.PutVarBytes(strings[curIndex]) - assert.Equal(t, expected[curIndex], encoder.Bytes()) + if result := encodeVarString(strings[curIndex]); !bytes.Equal(result, expected[curIndex]) { + t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) + } } } diff --git a/crypto/goolm/message/encoder.go b/crypto/goolm/message/encoder.go deleted file mode 100644 index 95ab6d41..00000000 --- a/crypto/goolm/message/encoder.go +++ /dev/null @@ -1,24 +0,0 @@ -package message - -import "encoding/binary" - -type Encoder struct { - buf []byte -} - -func (e *Encoder) Bytes() []byte { - return e.buf -} - -func (e *Encoder) PutByte(val byte) { - e.buf = append(e.buf, val) -} - -func (e *Encoder) PutVarInt(val uint64) { - e.buf = binary.AppendUvarint(e.buf, val) -} - -func (e *Encoder) PutVarBytes(data []byte) { - e.PutVarInt(uint64(len(data))) - e.buf = append(e.buf, data...) -} diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index c83540c1..ebd5b77e 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,12 +2,9 @@ package message import ( "bytes" - "fmt" - "io" - "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -25,87 +22,112 @@ type GroupMessage struct { } // Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present. -func (r *GroupMessage) Decode(input []byte) (err error) { +func (r *GroupMessage) Decode(input []byte) error { r.Version = 0 r.MessageIndex = 0 r.Ciphertext = nil if len(input) == 0 { return nil } - - decoder := NewDecoder(input[:len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize]) - r.Version, err = decoder.ReadByte() // First byte is the version - if err != nil { - return - } - if r.Version != protocolVersion { - return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) - } - - for { - // Read Key - if curKey, err := decoder.ReadVarInt(); err != nil { - if err == io.EOF { - // No more keys to read - return nil - } + //first Byte is always version + r.Version = input[0] + curPos := 1 + for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize { + //Read Key + curKey, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else if (curKey & 0b111) == 0 { - // The value is of type varint - if value, err := decoder.ReadVarInt(); err != nil { + } + curPos += readBytes + if (curKey & 0b111) == 0 { + //The value is of type varint + value, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else if curKey == messageIndexTag { - r.MessageIndex = uint32(value) + } + curPos += readBytes + switch curKey { + case messageIndexTag: + r.MessageIndex = value r.HasMessageIndex = true } } else if (curKey & 0b111) == 2 { - // The value is of type string - if value, err := decoder.ReadVarBytes(); err != nil { + //The value is of type string + value, readBytes := decodeVarString(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else if curKey == cipherTextTag { + } + curPos += readBytes + switch curKey { + case cipherTextTag: r.Ciphertext = value } } } + + return nil } -// EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message. +// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message. // If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended. -func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) { - var encoder Encoder - encoder.PutByte(r.Version) - encoder.PutVarInt(messageIndexTag) - encoder.PutVarInt(uint64(r.MessageIndex)) - encoder.PutVarInt(cipherTextTag) - encoder.PutVarBytes(r.Ciphertext) - mac, err := r.MAC(cipher, encoder.Bytes()) - if err != nil { - return nil, err +func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) { + var lengthOfMessage int + lengthOfMessage += 1 //Version + lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex) + lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext) + out := make([]byte, lengthOfMessage) + out[0] = r.Version + curPos := 1 + encodedTag := encodeVarInt(messageIndexTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue := encodeVarInt(r.MessageIndex) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(cipherTextTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.Ciphertext) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + if len(macKey) != 0 && cipher != nil { + mac, err := r.MAC(macKey, cipher, out) + if err != nil { + return nil, err + } + out = append(out, mac[:countMACBytesGroupMessage]...) } - ciphertextWithMAC := append(encoder.Bytes(), mac[:countMACBytesGroupMessage]...) - signature, err := signKey.Sign(ciphertextWithMAC) - return append(ciphertextWithMAC, signature...), err + if signKey != nil { + signature := signKey.Sign(out) + out = append(out, signature...) + } + return out, nil } // MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length. -func (r *GroupMessage) MAC(cipher aessha2.AESSHA2, ciphertext []byte) ([]byte, error) { - mac, err := cipher.MAC(ciphertext) +func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) { + mac, err := cipher.MAC(key, message) if err != nil { return nil, err } return mac[:countMACBytesGroupMessage], nil } +// VerifySignature verifies the givenSignature to the calculated signature of the message. +func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool { + return key.Verify(message, givenSignature) +} + // VerifySignature verifies the signature taken from the message to the calculated signature of the message. func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool { - signature := message[len(message)-crypto.Ed25519SignatureSize:] - message = message[:len(message)-crypto.Ed25519SignatureSize] + signature := message[len(message)-crypto.ED25519SignatureSize:] + message = message[:len(message)-crypto.ED25519SignatureSize] return key.Verify(message, signature) } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *GroupMessage) VerifyMAC(cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { - checkMac, err := r.MAC(cipher, ciphertext) +func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { + checkMac, err := r.MAC(key, cipher, message) if err != nil { return false, err } @@ -113,10 +135,10 @@ func (r *GroupMessage) VerifyMAC(cipher aessha2.AESSHA2, ciphertext, givenMAC [] } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *GroupMessage) VerifyMACInline(cipher aessha2.AESSHA2, message []byte) (bool, error) { - startMAC := len(message) - countMACBytesGroupMessage - crypto.Ed25519SignatureSize +func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { + startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize endMAC := startMAC + countMACBytesGroupMessage suplMac := message[startMAC:endMAC] message = message[:startMAC] - return r.VerifyMAC(cipher, message, suplMac) + return r.VerifyMAC(key, cipher, message, suplMac) } diff --git a/crypto/goolm/message/group_message_test.go b/crypto/goolm/message/group_message_test.go index 272138c4..4ae1f830 100644 --- a/crypto/goolm/message/group_message_test.go +++ b/crypto/goolm/message/group_message_test.go @@ -1,13 +1,9 @@ package message_test import ( + "bytes" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/crypto/goolm/aessha2" - "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -20,13 +16,22 @@ func TestGroupMessageDecode(t *testing.T) { msg := message.GroupMessage{} err := msg.Decode(messageRaw) - assert.NoError(t, err) - assert.EqualValues(t, 3, msg.Version) - assert.Equal(t, expectedMessageIndex, msg.MessageIndex) - assert.Equal(t, expectedCipherText, msg.Ciphertext) + if err != nil { + t.Fatal(err) + } + if msg.Version != 3 { + t.Fatalf("Expected Version to be 3 but go %d", msg.Version) + } + if msg.MessageIndex != expectedMessageIndex { + t.Fatalf("Expected message index to be %d but got %d", expectedMessageIndex, msg.MessageIndex) + } + if !bytes.Equal(msg.Ciphertext, expectedCipherText) { + t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) + } } func TestGroupMessageEncode(t *testing.T) { + expectedRaw := []byte("\x03\x08\xC8\x01\x12\x0aciphertexthmacsha2signature") hmacsha256 := []byte("hmacsha2") sign := []byte("signature") msg := message.GroupMessage{ @@ -34,29 +39,13 @@ func TestGroupMessageEncode(t *testing.T) { MessageIndex: 200, Ciphertext: []byte("ciphertext"), } - - cipher, err := aessha2.NewAESSHA2(nil, nil) - require.NoError(t, err) - encoded, err := msg.EncodeAndMACAndSign(cipher, crypto.Ed25519GenerateFromSeed(make([]byte, 32))) - assert.NoError(t, err) + encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil) + if err != nil { + t.Fatal(err) + } encoded = append(encoded, hmacsha256...) encoded = append(encoded, sign...) - expected := []byte{ - 0x03, // Version - 0x08, - 0xC8, // 200 - 0x01, - 0x12, - 0x0a, + if !bytes.Equal(encoded, expectedRaw) { + t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) } - expected = append(expected, []byte("ciphertext")...) - expected = append(expected, []byte{ - 0x6f, 0x95, 0x35, 0x51, 0xdc, 0xdb, 0xcb, 0x03, 0x0b, 0x22, 0xa2, 0xa7, 0xa1, 0xb7, 0x4f, 0x1a, - 0xa3, 0xe9, 0x5c, 0x05, 0x5d, 0x56, 0xdc, 0x5b, 0x87, 0x73, 0x05, 0x42, 0x2a, 0x59, 0x9a, 0x9a, - 0x26, 0x7a, 0x8d, 0xba, 0x65, 0xb2, 0x17, 0x65, 0x51, 0x6f, 0x37, 0xf3, 0x8f, 0xa1, 0x70, 0xd0, - 0xc4, 0x06, 0x05, 0xdc, 0x17, 0x71, 0x5e, 0x63, 0x84, 0xbe, 0xec, 0x7b, 0xa0, 0xc4, 0x08, 0xb8, - 0x9b, 0xc5, 0x08, 0x16, 0xad, 0xe5, 0x43, 0x0c, - }...) - expected = append(expected, []byte("hmacsha2signature")...) - assert.Equal(t, expected, encoded) } diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index b161a2d1..8b721aeb 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,12 +2,9 @@ package message import ( "bytes" - "fmt" - "io" - "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -27,7 +24,7 @@ type Message struct { } // Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present. -func (r *Message) Decode(input []byte) (err error) { +func (r *Message) Decode(input []byte) error { r.Version = 0 r.HasCounter = false r.Counter = 0 @@ -36,63 +33,89 @@ func (r *Message) Decode(input []byte) (err error) { if len(input) == 0 { return nil } - - decoder := NewDecoder(input[:len(input)-countMACBytesMessage]) - r.Version, err = decoder.ReadByte() // first byte is always version - if err != nil { - return - } - if r.Version != protocolVersion { - return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) - } - - for { - // Read Key - if curKey, err := decoder.ReadVarInt(); err != nil { - if err == io.EOF { - // No more keys to read - return nil - } + //first Byte is always version + r.Version = input[0] + curPos := 1 + for curPos < len(input)-countMACBytesMessage { + //Read Key + curKey, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else if (curKey & 0b111) == 0 { - // The value is of type varint - if value, err := decoder.ReadVarInt(); err != nil { + } + curPos += readBytes + if (curKey & 0b111) == 0 { + //The value is of type varint + value, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else if curKey == counterTag { - r.Counter = uint32(value) + } + curPos += readBytes + switch curKey { + case counterTag: r.HasCounter = true + r.Counter = value } } else if (curKey & 0b111) == 2 { - // The value is of type string - if value, err := decoder.ReadVarBytes(); err != nil { + //The value is of type string + value, readBytes := decodeVarString(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else if curKey == ratchetKeyTag { + } + curPos += readBytes + switch curKey { + case ratchetKeyTag: r.RatchetKey = value - } else if curKey == cipherTextKeyTag { + case cipherTextKeyTag: r.Ciphertext = value } } } + + return nil } // EncodeAndMAC encodes the message and creates the MAC with the key and the cipher. // If key or cipher is nil, no MAC is appended. -func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) { - var encoder Encoder - encoder.PutByte(r.Version) - encoder.PutVarInt(ratchetKeyTag) - encoder.PutVarBytes(r.RatchetKey) - encoder.PutVarInt(counterTag) - encoder.PutVarInt(uint64(r.Counter)) - encoder.PutVarInt(cipherTextKeyTag) - encoder.PutVarBytes(r.Ciphertext) - mac, err := cipher.MAC(encoder.Bytes()) - return append(encoder.Bytes(), mac[:countMACBytesMessage]...), err +func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) { + var lengthOfMessage int + lengthOfMessage += 1 //Version + lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey) + lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter) + lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext) + out := make([]byte, lengthOfMessage) + out[0] = r.Version + curPos := 1 + encodedTag := encodeVarInt(ratchetKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue := encodeVarString(r.RatchetKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(counterTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarInt(r.Counter) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(cipherTextKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.Ciphertext) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + if len(key) != 0 && cipher != nil { + mac, err := cipher.MAC(key, out) + if err != nil { + return nil, err + } + out = append(out, mac[:countMACBytesMessage]...) + } + return out, nil } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *Message) VerifyMAC(key []byte, cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { - checkMAC, err := cipher.MAC(ciphertext) +func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { + checkMAC, err := cipher.MAC(key, message) if err != nil { return false, err } @@ -100,7 +123,7 @@ func (r *Message) VerifyMAC(key []byte, cipher aessha2.AESSHA2, ciphertext, give } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *Message) VerifyMACInline(key []byte, cipher aessha2.AESSHA2, message []byte) (bool, error) { +func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { givenMAC := message[len(message)-countMACBytesMessage:] return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC) } diff --git a/crypto/goolm/message/message_test.go b/crypto/goolm/message/message_test.go index f3aa7108..4a9f29fb 100644 --- a/crypto/goolm/message/message_test.go +++ b/crypto/goolm/message/message_test.go @@ -1,11 +1,9 @@ package message_test import ( + "bytes" "testing" - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -16,16 +14,28 @@ func TestMessageDecode(t *testing.T) { msg := message.Message{} err := msg.Decode(messageRaw) - assert.NoError(t, err) - assert.EqualValues(t, 3, msg.Version) - assert.True(t, msg.HasCounter) - assert.EqualValues(t, 1, msg.Counter) - assert.Equal(t, expectedCipherText, msg.Ciphertext) - assert.EqualValues(t, expectedRatchetKey, msg.RatchetKey) + if err != nil { + t.Fatal(err) + } + if msg.Version != 3 { + t.Fatalf("Expected Version to be 3 but go %d", msg.Version) + } + if !msg.HasCounter { + t.Fatal("Expected to have counter") + } + if msg.Counter != 1 { + t.Fatalf("Expected counter to be 1 but got %d", msg.Counter) + } + if !bytes.Equal(msg.Ciphertext, expectedCipherText) { + t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) + } + if !bytes.Equal(msg.RatchetKey, expectedRatchetKey) { + t.Fatalf("expected '%s' but got '%s'", expectedRatchetKey, msg.RatchetKey) + } } func TestMessageEncode(t *testing.T) { - expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertext\x95\x95\x92\x72\x04\x70\x56\xcdhmacsha2") + expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2") hmacsha256 := []byte("hmacsha2") msg := message.Message{ Version: 3, @@ -33,10 +43,12 @@ func TestMessageEncode(t *testing.T) { RatchetKey: []byte("ratchetkey"), Ciphertext: []byte("ciphertext"), } - cipher, err := aessha2.NewAESSHA2(nil, nil) - assert.NoError(t, err) - encoded, err := msg.EncodeAndMAC(cipher) - assert.NoError(t, err) + encoded, err := msg.EncodeAndMAC(nil, nil) + if err != nil { + t.Fatal(err) + } encoded = append(encoded, hmacsha256...) - assert.Equal(t, expectedRaw, encoded) + if !bytes.Equal(encoded, expectedRaw) { + t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) + } } diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 4e3d495d..6e007e06 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,15 +1,11 @@ package message import ( - "fmt" - "io" - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" ) const ( - oneTimeKeyIDTag = 0x0A + oneTimeKeyIdTag = 0x0A baseKeyTag = 0x12 identityKeyTag = 0x1A messageTag = 0x22 @@ -23,13 +19,8 @@ type PreKeyMessage struct { Message []byte `json:"message"` } -// TODO deduplicate constant with one in session/olm_session.go -const ( - protocolVersion = 0x3 -) - // Decodes decodes the input and populates the corresponding fileds. -func (r *PreKeyMessage) Decode(input []byte) (err error) { +func (r *PreKeyMessage) Decode(input []byte) error { r.Version = 0 r.IdentityKey = nil r.BaseKey = nil @@ -38,55 +29,44 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) { if len(input) == 0 { return nil } - - decoder := NewDecoder(input) - r.Version, err = decoder.ReadByte() // first byte is always version - if err != nil { - if err == io.EOF { - return olm.ErrInputToSmall - } - return - } - if r.Version != protocolVersion { - return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) - } - - for { - // Read Key - if curKey, err := decoder.ReadVarInt(); err != nil { - if err == io.EOF { - return nil - } + //first Byte is always version + r.Version = input[0] + curPos := 1 + for curPos < len(input) { + //Read Key + curKey, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else if (curKey & 0b111) == 0 { - // The value is of type varint - if _, err = decoder.ReadVarInt(); err != nil { - if err == io.EOF { - return olm.ErrInputToSmall - } + } + curPos += readBytes + if (curKey & 0b111) == 0 { + //The value is of type varint + _, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err } + curPos += readBytes } else if (curKey & 0b111) == 2 { - // The value is of type string - if value, err := decoder.ReadVarBytes(); err != nil { - if err == io.EOF { - return olm.ErrInputToSmall - } + //The value is of type string + value, readBytes := decodeVarString(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { return err - } else { - switch curKey { - case oneTimeKeyIDTag: - r.OneTimeKey = value - case baseKeyTag: - r.BaseKey = value - case identityKeyTag: - r.IdentityKey = value - case messageTag: - r.Message = value - } + } + curPos += readBytes + switch curKey { + case oneTimeKeyIdTag: + r.OneTimeKey = value + case baseKeyTag: + r.BaseKey = value + case identityKeyTag: + r.IdentityKey = value + case messageTag: + r.Message = value } } } + + return nil } // CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message. @@ -94,25 +74,47 @@ func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey ok := true ok = ok && (theirIdentityKey != nil || r.IdentityKey != nil) if r.IdentityKey != nil { - ok = ok && (len(r.IdentityKey) == crypto.Curve25519PrivateKeyLength) + ok = ok && (len(r.IdentityKey) == crypto.Curve25519KeyLength) } ok = ok && len(r.Message) != 0 - ok = ok && len(r.BaseKey) == crypto.Curve25519PrivateKeyLength - ok = ok && len(r.OneTimeKey) == crypto.Curve25519PrivateKeyLength + ok = ok && len(r.BaseKey) == crypto.Curve25519KeyLength + ok = ok && len(r.OneTimeKey) == crypto.Curve25519KeyLength return ok } // Encode encodes the message. func (r *PreKeyMessage) Encode() ([]byte, error) { - var encoder Encoder - encoder.PutByte(r.Version) - encoder.PutVarInt(oneTimeKeyIDTag) - encoder.PutVarBytes(r.OneTimeKey) - encoder.PutVarInt(identityKeyTag) - encoder.PutVarBytes(r.IdentityKey) - encoder.PutVarInt(baseKeyTag) - encoder.PutVarBytes(r.BaseKey) - encoder.PutVarInt(messageTag) - encoder.PutVarBytes(r.Message) - return encoder.Bytes(), nil + var lengthOfMessage int + lengthOfMessage += 1 //Version + lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey) + lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey) + lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey) + lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message) + out := make([]byte, lengthOfMessage) + out[0] = r.Version + curPos := 1 + encodedTag := encodeVarInt(oneTimeKeyIdTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue := encodeVarString(r.OneTimeKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(identityKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.IdentityKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(baseKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.BaseKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(messageTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.Message) + copy(out[curPos:], encodedValue) + return out, nil } diff --git a/crypto/goolm/message/prekey_message_test.go b/crypto/goolm/message/prekey_message_test.go index fe196e31..431d27d5 100644 --- a/crypto/goolm/message/prekey_message_test.go +++ b/crypto/goolm/message/prekey_message_test.go @@ -1,10 +1,9 @@ package message_test import ( + "bytes" "testing" - "github.com/stretchr/testify/assert" - "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -20,14 +19,29 @@ func TestPreKeyMessageDecode(t *testing.T) { msg := message.PreKeyMessage{} err := msg.Decode(messageRaw) - assert.NoError(t, err) - assert.EqualValues(t, 3, msg.Version) - assert.EqualValues(t, expectedOneTimeKey, msg.OneTimeKey) - assert.EqualValues(t, expectedIdKey, msg.IdentityKey) - assert.EqualValues(t, expectedbaseKey, msg.BaseKey) - assert.Equal(t, expectedmessage, msg.Message) + if err != nil { + t.Fatal(err) + } + if msg.Version != 3 { + t.Fatalf("Expected Version to be 3 but go %d", msg.Version) + } + if !bytes.Equal(msg.OneTimeKey, expectedOneTimeKey) { + t.Fatalf("expected '%s' but got '%s'", expectedOneTimeKey, msg.OneTimeKey) + } + if !bytes.Equal(msg.IdentityKey, expectedIdKey) { + t.Fatalf("expected '%s' but got '%s'", expectedIdKey, msg.IdentityKey) + } + if !bytes.Equal(msg.BaseKey, expectedbaseKey) { + t.Fatalf("expected '%s' but got '%s'", expectedbaseKey, msg.BaseKey) + } + if !bytes.Equal(msg.Message, expectedmessage) { + t.Fatalf("expected '%s' but got '%s'", expectedmessage, msg.Message) + } theirIDKey := crypto.Curve25519PublicKey(expectedIdKey) - assert.True(t, msg.CheckFields(&theirIDKey), "field check failed") + checked := msg.CheckFields(&theirIDKey) + if !checked { + t.Fatal("field check failed") + } } func TestPreKeyMessageEncode(t *testing.T) { @@ -40,6 +54,10 @@ func TestPreKeyMessageEncode(t *testing.T) { Message: []byte("message"), } encoded, err := msg.Encode() - assert.NoError(t, err) - assert.Equal(t, expectedRaw, encoded) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(encoded, expectedRaw) { + t.Fatalf("got other than expected:\nExpected:\n%v\nGot:\n%v", expectedRaw, encoded) + } } diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index d58dbb21..f539cce5 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -4,8 +4,8 @@ import ( "encoding/binary" "fmt" + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -32,10 +32,10 @@ func (s MegolmSessionExport) Encode() []byte { // Decode populates the struct with the data encoded in input. func (s *MegolmSessionExport) Decode(input []byte) error { if len(input) != 165 { - return fmt.Errorf("decrypt: %w", olm.ErrBadInput) + return fmt.Errorf("decrypt: %w", goolm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion) + return fmt.Errorf("decrypt: %w", goolm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index d04ef15a..c5393f50 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -4,8 +4,8 @@ import ( "encoding/binary" "fmt" + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -20,29 +20,29 @@ type MegolmSessionSharing struct { } // Encode returns the encoded message in the correct format with the signature by key appended. -func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) ([]byte, error) { +func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte { output := make([]byte, 229) output[0] = sessionSharingVersion binary.BigEndian.PutUint32(output[1:], s.Counter) copy(output[5:], s.RatchetData[:]) copy(output[133:], key.PublicKey) - signature, err := key.Sign(output[:165]) + signature := key.Sign(output[:165]) copy(output[165:], signature) - return output, err + return output } // VerifyAndDecode verifies the input and populates the struct with the data encoded in input. func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { if len(input) != 229 { - return fmt.Errorf("verify: %w", olm.ErrBadInput) + return fmt.Errorf("verify: %w", goolm.ErrBadInput) } publicKey := crypto.Ed25519PublicKey(input[133:165]) if !publicKey.Verify(input[:165], input[165:]) { - return fmt.Errorf("verify: %w", olm.ErrBadVerification) + return fmt.Errorf("verify: %w", goolm.ErrBadVerification) } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion) + return fmt.Errorf("verify: %w", goolm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/olm/chain.go b/crypto/goolm/olm/chain.go new file mode 100644 index 00000000..403637a4 --- /dev/null +++ b/crypto/goolm/olm/chain.go @@ -0,0 +1,258 @@ +package olm + +import ( + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" +) + +const ( + chainKeySeed = 0x02 + messageKeyLength = 32 +) + +// chainKey wraps the index and the public key +type chainKey struct { + Index uint32 `json:"index"` + Key crypto.Curve25519PublicKey `json:"key"` +} + +// advance advances the chain +func (c *chainKey) advance() { + c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed}) + c.Index++ +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain key accordingly. It returns the number of bytes read. +func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.Key.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + r.Index, readBytes, err = libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r chainKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, fmt.Errorf("pickle chain key: %w", goolm.ErrValueTooShort) + } + written, err := r.Key.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle chain key: %w", err) + } + written += libolmpickle.PickleUInt32(r.Index, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain key will have. +func (r chainKey) PickleLen() int { + length := r.Key.PickleLen() + length += libolmpickle.PickleUInt32Len(r.Index) + return length +} + +// senderChain is a chain for sending messages +type senderChain struct { + RKey crypto.Curve25519KeyPair `json:"ratchet_key"` + CKey chainKey `json:"chain_key"` + IsSet bool `json:"set"` +} + +// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair. +func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain { + return &senderChain{ + RKey: ratchet, + CKey: chainKey{ + Index: 0, + Key: key, + }, + IsSet: true, + } +} + +// advance advances the chain +func (s *senderChain) advance() { + s.CKey.advance() +} + +// ratchetKey returns the ratchet key pair. +func (s senderChain) ratchetKey() crypto.Curve25519KeyPair { + return s.RKey +} + +// chainKey returns the current chainKey. +func (s senderChain) chainKey() chainKey { + return s.CKey +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. +func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.RKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r senderChain) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) + } + written, err := r.RKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle sender chain: %w", err) + } + writtenChain, err := r.CKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle sender chain: %w", err) + } + written += writtenChain + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain will have. +func (r senderChain) PickleLen() int { + length := r.RKey.PickleLen() + length += r.CKey.PickleLen() + return length +} + +// senderChain is a chain for receiving messages +type receiverChain struct { + RKey crypto.Curve25519PublicKey `json:"ratchet_key"` + CKey chainKey `json:"chain_key"` +} + +// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. +func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain { + return &receiverChain{ + RKey: ratchet, + CKey: chainKey{ + Index: 0, + Key: chain, + }, + } +} + +// advance advances the chain +func (s *receiverChain) advance() { + s.CKey.advance() +} + +// ratchetKey returns the ratchet public key. +func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey { + return s.RKey +} + +// chainKey returns the current chainKey. +func (s receiverChain) chainKey() chainKey { + return s.CKey +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. +func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.RKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r receiverChain) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) + } + written, err := r.RKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle sender chain: %w", err) + } + writtenChain, err := r.CKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle sender chain: %w", err) + } + written += writtenChain + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain will have. +func (r receiverChain) PickleLen() int { + length := r.RKey.PickleLen() + length += r.CKey.PickleLen() + return length +} + +// messageKey wraps the index and the key of a message +type messageKey struct { + Index uint32 `json:"index"` + Key []byte `json:"key"` +} + +// UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read. +func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + ratchetKey, readBytes, err := libolmpickle.UnpickleBytes(value, messageKeyLength) + if err != nil { + return 0, err + } + m.Key = ratchetKey + curPos += readBytes + keyID, readBytes, err := libolmpickle.UnpickleUInt32(value[:curPos]) + if err != nil { + return 0, err + } + curPos += readBytes + m.Index = keyID + return curPos, nil +} + +// PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (m messageKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < m.PickleLen() { + return 0, fmt.Errorf("pickle message key: %w", goolm.ErrValueTooShort) + } + written := 0 + if len(m.Key) != messageKeyLength { + written += libolmpickle.PickleBytes(make([]byte, messageKeyLength), target) + } else { + written += libolmpickle.PickleBytes(m.Key, target) + } + written += libolmpickle.PickleUInt32(m.Index, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled message key will have. +func (r messageKey) PickleLen() int { + length := libolmpickle.PickleBytesLen(make([]byte, messageKeyLength)) + length += libolmpickle.PickleUInt32Len(r.Index) + return length +} diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/olm/olm.go similarity index 51% rename from crypto/goolm/ratchet/olm.go rename to crypto/goolm/olm/olm.go index 9901ada8..299ec7c4 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/olm/olm.go @@ -1,19 +1,16 @@ -// Package ratchet provides the ratchet used by the olm protocol -package ratchet +// olm provides the ratchet used by the olm protocol +package olm import ( - "crypto/hmac" - "crypto/sha256" "fmt" "io" - "golang.org/x/crypto/hkdf" - - "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -30,8 +27,6 @@ const ( sharedKeyLength = 32 ) -var olmKeysKDFInfo = []byte("OLM_KEYS") - // KdfInfo has the infos used for the kdf var KdfInfo = struct { Root []byte @@ -41,6 +36,8 @@ var KdfInfo = struct { Ratchet: []byte("OLM_RATCHET"), } +var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS")) + // Ratchet represents the olm ratchet as described in // // https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md @@ -67,12 +64,13 @@ type Ratchet struct { // New creates a new ratchet, setting the kdfInfos and cipher. func New() *Ratchet { - return &Ratchet{} + r := &Ratchet{} + return r } // InitializeAsBob initializes this ratchet from a receiving point of view (only first message). func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error { - derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root) + derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return err @@ -85,7 +83,7 @@ func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Cu // InitializeAsAlice initializes this ratchet from a sending point of view (only first message). func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error { - derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root) + derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return err @@ -96,11 +94,11 @@ func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Cu return nil } -// Encrypt encrypts the message in a message.Message with MAC. -func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { +// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations. +func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) { var err error if !r.SenderChains.IsSet { - newRatchetKey, err := crypto.Curve25519GenerateKey() + newRatchetKey, err := crypto.Curve25519GenerateKey(reader) if err != nil { return nil, err } @@ -115,11 +113,7 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { messageKey := r.createMessageKeys(r.SenderChains.chainKey()) r.SenderChains.advance() - cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) - if err != nil { - return nil, err - } - encryptedText, err := cipher.Encrypt(plaintext) + encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -130,10 +124,15 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { message.RatchetKey = r.SenderChains.ratchetKey().PublicKey message.Ciphertext = encryptedText //creating the mac is done in encode - return message.EncodeAndMAC(cipher) + output, err := message.EncodeAndMAC(messageKey.Key, RatchetCipher) + if err != nil { + return nil, err + } + + return output, nil } -// Decrypt decrypts the ciphertext and verifies the MAC. +// Decrypt decrypts the ciphertext and verifies the MAC. If reader is nil, crypto/rand is used for key generations. func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { message := &message.Message{} //The mac is not verified here, as we do not know the key yet @@ -142,10 +141,10 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion) + return nil, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { - return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) } var receiverChainFromMessage *receiverChain for curChainIndex := range r.ReceiverChains { @@ -154,40 +153,53 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { break } } + var result []byte if receiverChainFromMessage == nil { //Advancing the chain is done in this method - return r.decryptForNewChain(message, input) + result, err = r.decryptForNewChain(message, input) + if err != nil { + return nil, err + } } else if receiverChainFromMessage.chainKey().Index > message.Counter { // No need to advance the chain // Chain already advanced beyond the key for this message // Check if the message keys are in the skipped key list. + foundSkippedKey := false for curSkippedIndex := range r.SkippedMessageKeys { - if message.Counter != r.SkippedMessageKeys[curSkippedIndex].MKey.Index { - continue - } - - // Found the key for this message. Check the MAC. - if cipher, err := aessha2.NewAESSHA2(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, olmKeysKDFInfo); err != nil { - return nil, err - } else if verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, cipher, input); err != nil { - return nil, err - } else if !verified { - return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC) - } else if result, err := cipher.Decrypt(message.Ciphertext); err != nil { - return nil, fmt.Errorf("cipher decrypt: %w", err) - } else if len(result) != 0 { - // Remove the key from the skipped keys now that we've - // decoded the message it corresponds to. - r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1] - r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1] - return result, nil + if message.Counter == r.SkippedMessageKeys[curSkippedIndex].MKey.Index { + // Found the key for this message. Check the MAC. + verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input) + if err != nil { + return nil, err + } + if !verified { + return nil, fmt.Errorf("decrypt from skipped message keys: %w", goolm.ErrBadMAC) + } + result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) + if err != nil { + return nil, fmt.Errorf("cipher decrypt: %w", err) + } + if len(result) != 0 { + // Remove the key from the skipped keys now that we've + // decoded the message it corresponds to. + r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1] + r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1] + } + foundSkippedKey = true } } - return nil, fmt.Errorf("decrypt: %w", olm.ErrMessageKeyNotFound) + if !foundSkippedKey { + return nil, fmt.Errorf("decrypt: %w", goolm.ErrMessageKeyNotFound) + } } else { //Advancing the chain is done in this method - return r.decryptForExistingChain(receiverChainFromMessage, message, input) + result, err = r.decryptForExistingChain(receiverChainFromMessage, message, input) + if err != nil { + return nil, err + } } + + return result, nil } // advanceRootKey created the next root key and returns the next chainKey @@ -196,7 +208,7 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc if err != nil { return nil, err } - derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, r.RootKey, KdfInfo.Ratchet) + derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return nil, err @@ -207,22 +219,20 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc // createMessageKeys returns the messageKey derived from the chainKey func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey { - hash := hmac.New(sha256.New, chainKey.Key) - hash.Write([]byte{messageKeySeed}) - return messageKey{ - Key: hash.Sum(nil), - Index: chainKey.Index, - } + res := messageKey{} + res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed}) + res.Index = chainKey.Index + return res } // decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified. func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) { if message.Counter < chain.CKey.Index { - return nil, fmt.Errorf("decrypt: %w", olm.ErrChainTooHigh) + return nil, fmt.Errorf("decrypt: %w", goolm.ErrChainTooHigh) } // Limit the number of hashes we're prepared to compute if message.Counter-chain.CKey.Index > maxMessageGap { - return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrMsgIndexTooHigh) + return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrMsgIndexTooHigh) } for chain.CKey.Index < message.Counter { messageKey := r.createMessageKeys(chain.chainKey()) @@ -235,18 +245,14 @@ func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message } messageKey := r.createMessageKeys(chain.chainKey()) chain.advance() - cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) - if err != nil { - return nil, err - } - verified, err := message.VerifyMACInline(messageKey.Key, cipher, rawMessage) + verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage) if err != nil { return nil, err } if !verified { - return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrBadMAC) + return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrBadMAC) } - return cipher.Decrypt(message.Ciphertext) + return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext) } // decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key. @@ -254,11 +260,11 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte // They shouldn't move to a new chain until we've sent them a message // acknowledging the last one if !r.SenderChains.IsSet { - return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrProtocolViolation) + return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrProtocolViolation) } // Limit the number of hashes we're prepared to compute if message.Counter > maxMessageGap { - return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrMsgIndexTooHigh) + return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrMsgIndexTooHigh) } newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey) @@ -275,88 +281,152 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte */ r.SenderChains = senderChain{} - return r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage) + decrypted, err := r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage) + if err != nil { + return nil, err + } + return decrypted, nil } // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return libolmpickle.PickleAsJSON(r, olmPickleVersion, key) + return utilities.PickleAsJSON(r, olmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return libolmpickle.UnpickleAsJSON(r, pickled, key, olmPickleVersion) + return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion) } -// UnpickleLibOlm unpickles the unencryted value and populates the [Ratchet] -// accordingly. -func (r *Ratchet) UnpickleLibOlm(decoder *libolmpickle.Decoder, includesChainIndex bool) error { - if err := r.RootKey.UnpickleLibOlm(decoder); err != nil { - return err - } - senderChainsCount, err := decoder.ReadUInt32() +// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. +func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, error) { + //read ratchet data + curPos := 0 + readBytes, err := r.RootKey.UnpickleLibOlm(value) if err != nil { - return err + return 0, err } - - for i := uint32(0); i < senderChainsCount; i++ { + curPos += readBytes + countSenderChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of sender chain + if err != nil { + return 0, err + } + curPos += readBytes + for i := uint32(0); i < countSenderChains; i++ { if i == 0 { - // only the first sender key is stored - err = r.SenderChains.UnpickleLibOlm(decoder) + //only first is stored + readBytes, err := r.SenderChains.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes r.SenderChains.IsSet = true } else { - // just eat the values - err = (&senderChain{}).UnpickleLibOlm(decoder) + dummy := senderChain{} + readBytes, err := dummy.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes } + } + countReceivChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of recevier chain + if err != nil { + return 0, err + } + curPos += readBytes + r.ReceiverChains = make([]receiverChain, countReceivChains) + for i := uint32(0); i < countReceivChains; i++ { + readBytes, err := r.ReceiverChains[i].UnpickleLibOlm(value[curPos:]) if err != nil { - return err + return 0, err } + curPos += readBytes } - - receiverChainCount, err := decoder.ReadUInt32() + countSkippedMessageKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of skippedMessageKeys if err != nil { - return err + return 0, err } - r.ReceiverChains = make([]receiverChain, receiverChainCount) - for i := uint32(0); i < receiverChainCount; i++ { - if err := r.ReceiverChains[i].UnpickleLibOlm(decoder); err != nil { - return err + curPos += readBytes + r.SkippedMessageKeys = make([]skippedMessageKey, countSkippedMessageKeys) + for i := uint32(0); i < countSkippedMessageKeys; i++ { + readBytes, err := r.SkippedMessageKeys[i].UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err } + curPos += readBytes } - - skippedMessageKeysCount, err := decoder.ReadUInt32() - if err != nil { - return err - } - r.SkippedMessageKeys = make([]skippedMessageKey, skippedMessageKeysCount) - for i := uint32(0); i < skippedMessageKeysCount; i++ { - if err := r.SkippedMessageKeys[i].UnpickleLibOlm(decoder); err != nil { - return err - } - } - - // pickle version 0x80000001 includes a chain index; pickle version 1 does not. + // pickle v 0x80000001 includes a chain index; pickle v1 does not. if includesChainIndex { - _, err = decoder.ReadUInt32() - return err + _, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes } - return nil + return curPos, nil } -// PickleLibOlm pickles the ratchet into the encoder. -func (r Ratchet) PickleLibOlm(encoder *libolmpickle.Encoder) { - r.RootKey.PickleLibOlm(encoder) - r.SenderChains.PickleLibOlm(encoder) - - // Receiver Chains - encoder.WriteUInt32(uint32(len(r.ReceiverChains))) +// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r Ratchet) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, fmt.Errorf("pickle ratchet: %w", goolm.ErrValueTooShort) + } + written, err := r.RootKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle ratchet: %w", err) + } + if r.SenderChains.IsSet { + written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1 + writtenSender, err := r.SenderChains.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle ratchet: %w", err) + } + written += writtenSender + } else { + written += libolmpickle.PickleUInt32(0, target[written:]) //Length of sender chain + } + written += libolmpickle.PickleUInt32(uint32(len(r.ReceiverChains)), target[written:]) for _, curChain := range r.ReceiverChains { - curChain.PickleLibOlm(encoder) + writtenChain, err := curChain.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle ratchet: %w", err) + } + written += writtenChain } - - // Skipped Message Keys - encoder.WriteUInt32(uint32(len(r.SkippedMessageKeys))) + written += libolmpickle.PickleUInt32(uint32(len(r.SkippedMessageKeys)), target[written:]) for _, curChain := range r.SkippedMessageKeys { - curChain.PickleLibOlm(encoder) + writtenChain, err := curChain.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle ratchet: %w", err) + } + written += writtenChain } + return written, nil +} + +// PickleLen returns the actual number of bytes the pickled ratchet will have. +func (r Ratchet) PickleLen() int { + length := r.RootKey.PickleLen() + if r.SenderChains.IsSet { + length += libolmpickle.PickleUInt32Len(1) + length += r.SenderChains.PickleLen() + } else { + length += libolmpickle.PickleUInt32Len(0) + } + length += libolmpickle.PickleUInt32Len(uint32(len(r.ReceiverChains))) + length += len(r.ReceiverChains) * receiverChain{}.PickleLen() + length += libolmpickle.PickleUInt32Len(uint32(len(r.SkippedMessageKeys))) + length += len(r.SkippedMessageKeys) * skippedMessageKey{}.PickleLen() + return length +} + +// PickleLen returns the minimum number of bytes the pickled ratchet must have. +func (r Ratchet) PickleLenMin() int { + length := r.RootKey.PickleLen() + length += libolmpickle.PickleUInt32Len(0) + length += libolmpickle.PickleUInt32Len(0) + length += libolmpickle.PickleUInt32Len(0) + return length } diff --git a/crypto/goolm/olm/olm_test.go b/crypto/goolm/olm/olm_test.go new file mode 100644 index 00000000..974ffc5e --- /dev/null +++ b/crypto/goolm/olm/olm_test.go @@ -0,0 +1,186 @@ +package olm_test + +import ( + "bytes" + "encoding/json" + "testing" + + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/olm" +) + +var ( + sharedSecret = []byte("A secret") +) + +func initializeRatchets() (*olm.Ratchet, *olm.Ratchet, error) { + olm.KdfInfo = struct { + Root []byte + Ratchet []byte + }{ + Root: []byte("Olm"), + Ratchet: []byte("OlmRatchet"), + } + olm.RatchetCipher = cipher.NewAESSHA256([]byte("OlmMessageKeys")) + aliceRatchet := olm.New() + bobRatchet := olm.New() + + aliceKey, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + return nil, nil, err + } + + aliceRatchet.InitializeAsAlice(sharedSecret, aliceKey) + bobRatchet.InitializeAsBob(sharedSecret, aliceKey.PublicKey) + return aliceRatchet, bobRatchet, nil +} + +func TestSendReceive(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + + plainText := []byte("Hello Bob") + + //Alice sends Bob a message + encryptedMessage, err := aliceRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + + decrypted, err := bobRatchet.Decrypt(encryptedMessage) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + + //Bob sends Alice a message + plainText = []byte("Hello Alice") + encryptedMessage, err = bobRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err = aliceRatchet.Decrypt(encryptedMessage) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } +} + +func TestOutOfOrder(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + + plainText1 := []byte("First Message") + plainText2 := []byte("Second Messsage. A bit longer than the first.") + + /* Alice sends Bob two messages and they arrive out of order */ + message1Encrypted, err := aliceRatchet.Encrypt(plainText1, nil) + if err != nil { + t.Fatal(err) + } + message2Encrypted, err := aliceRatchet.Encrypt(plainText2, nil) + if err != nil { + t.Fatal(err) + } + + decrypted2, err := bobRatchet.Decrypt(message2Encrypted) + if err != nil { + t.Fatal(err) + } + decrypted1, err := bobRatchet.Decrypt(message1Encrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText1, decrypted1) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText1, decrypted1) + } + if !bytes.Equal(plainText2, decrypted2) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText2, decrypted2) + } +} + +func TestMoreMessages(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + plainText := []byte("These 15 bytes") + for i := 0; i < 8; i++ { + messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + } + for i := 0; i < 8; i++ { + messageEncrypted, err := bobRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := aliceRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + } + messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } +} + +func TestJSONEncoding(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + marshaled, err := json.Marshal(aliceRatchet) + if err != nil { + t.Fatal(err) + } + + newRatcher := olm.Ratchet{} + err = json.Unmarshal(marshaled, &newRatcher) + if err != nil { + t.Fatal(err) + } + + plainText := []byte("These 15 bytes") + + messageEncrypted, err := newRatcher.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + +} diff --git a/crypto/goolm/olm/skipped_message.go b/crypto/goolm/olm/skipped_message.go new file mode 100644 index 00000000..944337f6 --- /dev/null +++ b/crypto/goolm/olm/skipped_message.go @@ -0,0 +1,55 @@ +package olm + +import ( + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" +) + +// skippedMessageKey stores a skipped message key +type skippedMessageKey struct { + RKey crypto.Curve25519PublicKey `json:"ratchet_key"` + MKey messageKey `json:"message_key"` +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. +func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.RKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = r.MKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) + } + written, err := r.RKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle sender chain: %w", err) + } + writtenChain, err := r.MKey.PickleLibOlm(target) + if err != nil { + return 0, fmt.Errorf("pickle sender chain: %w", err) + } + written += writtenChain + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain will have. +func (r skippedMessageKey) PickleLen() int { + length := r.RKey.PickleLen() + length += r.MKey.PickleLen() + return length +} diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index cdb20eb1..d08e09f4 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -2,13 +2,14 @@ package pk import ( "encoding/base64" + "errors" "fmt" - "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/id" ) @@ -24,7 +25,7 @@ type Decryption struct { // NewDecryption returns a new Decryption with a new generated key pair. func NewDecryption() (*Decryption, error) { - keyPair, err := crypto.Curve25519GenerateKey() + keyPair, err := crypto.Curve25519GenerateKey(nil) if err != nil { return nil, err } @@ -55,67 +56,110 @@ func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey { } // Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. -func (s Decryption) Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) { - if keyDecoded, err := base64.RawStdEncoding.DecodeString(string(ephemeralKey)); err != nil { +func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) { + keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key)) + if err != nil { return nil, err - } else if sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded); err != nil { - return nil, err - } else if decodedMAC, err := goolmbase64.Decode(mac); err != nil { - return nil, err - } else if cipher, err := aessha2.NewAESSHA2(sharedSecret, nil); err != nil { - return nil, err - } else if verified, err := cipher.VerifyMAC(ciphertext, decodedMAC); err != nil { - return nil, err - } else if !verified { - return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) - } else { - return cipher.Decrypt(ciphertext) } + sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded) + if err != nil { + return nil, err + } + decodedMAC, err := goolm.Base64Decode(mac) + if err != nil { + return nil, err + } + cipher := cipher.NewAESSHA256(nil) + verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC) + if err != nil { + return nil, err + } + if !verified { + return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) + } + plaintext, err := cipher.Decrypt(sharedSecret, ciphertext) + if err != nil { + return nil, err + } + return plaintext, nil } // PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) { - return libolmpickle.PickleAsJSON(a, decryptionPickleVersionJSON, key) + return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key) } // UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error { - return libolmpickle.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) + return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Decryption) Unpickle(pickled, key []byte) error { - decrypted, err := libolmpickle.Unpickle(key, pickled) + decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err } - return a.UnpickleLibOlm(decrypted) + _, err = a.UnpickleLibOlm(decrypted) + return err } // UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read. -func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { - decoder := libolmpickle.NewDecoder(unpickled) - pickledVersion, err := decoder.ReadUInt32() +func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) if err != nil { - return err + return 0, err } - if pickledVersion == decryptionPickleVersionLibOlm { - return a.KeyPair.UnpickleLibOlm(decoder) - } else { - return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm) + switch pickledVersion { + case decryptionPickleVersionLibOlm: + default: + return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) } + readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil } // Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). func (a Decryption) Pickle(key []byte) ([]byte, error) { - return libolmpickle.Pickle(key, a.PickleLibOlm()) + pickeledBytes := make([]byte, a.PickleLen()) + written, err := a.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err + } + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil } -// PickleLibOlm pickles the [Decryption] into the encoder. -func (a Decryption) PickleLibOlm() []byte { - encoder := libolmpickle.NewEncoder() - encoder.WriteUInt32(decryptionPickleVersionLibOlm) - a.KeyPair.PickleLibOlm(encoder) - return encoder.Bytes() +// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (a Decryption) PickleLibOlm(target []byte) (int, error) { + if len(target) < a.PickleLen() { + return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort) + } + written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) + writtenKey, err := a.KeyPair.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle Decryption: %w", err) + } + written += writtenKey + return written, nil +} + +// PickleLen returns the number of bytes the pickled Decryption will have. +func (a Decryption) PickleLen() int { + length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm) + length += a.KeyPair.PickleLen() + return length } diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 2897d9b0..dc50a6bb 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -5,9 +5,9 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" ) // Encryption is used to encrypt pk messages @@ -36,14 +36,14 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat if err != nil { return nil, nil, err } - cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) + cipher := cipher.NewAESSHA256(nil) + ciphertext, err = cipher.Encrypt(sharedSecret, plaintext) if err != nil { return nil, nil, err } - ciphertext, err = cipher.Encrypt(plaintext) + mac, err = cipher.MAC(sharedSecret, ciphertext) if err != nil { return nil, nil, err } - mac, err = cipher.MAC(ciphertext) - return ciphertext, goolmbase64.Encode(mac), err + return ciphertext, goolm.Base64Encode(mac), nil } diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go index 4b247430..7ac524be 100644 --- a/crypto/goolm/pk/pk_test.go +++ b/crypto/goolm/pk/pk_test.go @@ -1,13 +1,14 @@ package pk_test import ( + "bytes" "encoding/base64" "testing" - "github.com/stretchr/testify/assert" - + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/pk" + "maunium.net/go/mautrix/id" ) func TestEncryptionDecryption(t *testing.T) { @@ -26,20 +27,34 @@ func TestEncryptionDecryption(t *testing.T) { } bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) - assert.NoError(t, err) - assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct") - assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { + t.Fatal("public key not correct") + } + if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { + t.Fatal("private key not correct") + } encryption, err := pk.NewEncryption(decryption.PublicKey()) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } plaintext := []byte("This is a test") ciphertext, mac, err := encryption.Encrypt(plaintext, bobPrivate) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } - decrypted, err := decryption.Decrypt(bobPublic, mac, ciphertext) - assert.NoError(t, err) - assert.EqualValues(t, plaintext, decrypted, "message not equal") + decrypted, err := decryption.Decrypt(ciphertext, mac, id.Curve25519(bobPublic)) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decrypted, plaintext) { + t.Fatal("message not equal") + } } func TestSigning(t *testing.T) { @@ -52,20 +67,29 @@ func TestSigning(t *testing.T) { message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.") signing, _ := pk.NewSigningFromSeed(seed) signature, err := signing.Sign(message) - assert.NoError(t, err) - signatureDecoded, err := base64.RawStdEncoding.DecodeString(string(signature)) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } + signatureDecoded, err := goolm.Base64Decode(signature) + if err != nil { + t.Fatal(err) + } pubKeyEncoded := signing.PublicKey() pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKeyEncoded)) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } pubKey := crypto.Ed25519PublicKey(pubKeyDecoded) verified := pubKey.Verify(message, signatureDecoded) - assert.True(t, verified, "signature did not verify") - + if !verified { + t.Fatal("signature did not verify") + } copy(signatureDecoded[0:], []byte("m")) verified = pubKey.Verify(message, signatureDecoded) - assert.False(t, verified, "signature verified with wrong message") + if verified { + t.Fatal("signature did verify") + } } func TestDecryptionPickling(t *testing.T) { @@ -77,19 +101,37 @@ func TestDecryptionPickling(t *testing.T) { } alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) - assert.NoError(t, err) - assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct") - assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { + t.Fatal("public key not correct") + } + if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { + t.Fatal("private key not correct") + } pickleKey := []byte("secret_key") expectedPickle := []byte("qx37WTQrjZLz5tId/uBX9B3/okqAbV1ofl9UnHKno1eipByCpXleAAlAZoJgYnCDOQZDQWzo3luTSfkF9pU1mOILCbbouubs6TVeDyPfgGD9i86J8irHjA") pickled, err := decryption.Pickle(pickleKey) - assert.NoError(t, err) - assert.EqualValues(t, expectedPickle, pickled, "pickle not as expected") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expectedPickle, pickled) { + t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle) + } newDecription, err := pk.NewDecryption() - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } err = newDecription.Unpickle(pickled, pickleKey) - assert.NoError(t, err) - assert.EqualValues(t, alicePublic, newDecription.PublicKey(), "public key not correct") - assert.EqualValues(t, alicePrivate, newDecription.PrivateKey(), "private key not correct") + if err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) { + t.Fatal("public key not correct") + } + if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) { + t.Fatal("private key not correct") + } } diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go deleted file mode 100644 index 0e27b568..00000000 --- a/crypto/goolm/pk/register.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package pk - -import "maunium.net/go/mautrix/crypto/olm" - -func Register() { - olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { - return NewSigningFromSeed(seed) - } - olm.InitNewPKSigning = func() (olm.PKSigning, error) { - return NewSigning() - } - olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { - return NewDecryptionFromPrivate(privateKey) - } -} diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index 61b31b6f..a98330d5 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -7,8 +7,8 @@ import ( "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/id" ) @@ -48,8 +48,8 @@ func (s Signing) PublicKey() id.Ed25519 { // Sign returns the signature of the message base64 encoded. func (s Signing) Sign(message []byte) ([]byte, error) { - signature, err := s.keyPair.Sign(message) - return goolmbase64.Encode(signature), err + signature := s.keyPair.Sign(message) + return goolm.Base64Encode(signature), nil } // SignJSON creates a signature for the given object after encoding it to @@ -62,5 +62,8 @@ func (s Signing) SignJSON(obj any) (string, error) { objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") signature, err := s.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) - return string(signature), err + if err != nil { + return "", err + } + return string(signature), nil } diff --git a/crypto/goolm/ratchet/chain.go b/crypto/goolm/ratchet/chain.go deleted file mode 100644 index 5deb90f5..00000000 --- a/crypto/goolm/ratchet/chain.go +++ /dev/null @@ -1,170 +0,0 @@ -package ratchet - -import ( - "crypto/hmac" - "crypto/sha256" - - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" -) - -const ( - chainKeySeed = 0x02 - messageKeyLength = 32 -) - -// chainKey wraps the index and the public key -type chainKey struct { - Index uint32 `json:"index"` - Key crypto.Curve25519PublicKey `json:"key"` -} - -// advance advances the chain -func (c *chainKey) advance() { - hash := hmac.New(sha256.New, c.Key) - hash.Write([]byte{chainKeySeed}) - c.Key = hash.Sum(nil) - c.Index++ -} - -// UnpickleLibOlm unpickles the unencryted value and populates the chain key accordingly. -func (r *chainKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - err := r.Key.UnpickleLibOlm(decoder) - if err != nil { - return err - } - r.Index, err = decoder.ReadUInt32() - return err -} - -// PickleLibOlm pickles the chain key into the encoder. -func (r chainKey) PickleLibOlm(encoder *libolmpickle.Encoder) { - r.Key.PickleLibOlm(encoder) - encoder.WriteUInt32(r.Index) -} - -// senderChain is a chain for sending messages -type senderChain struct { - RKey crypto.Curve25519KeyPair `json:"ratchet_key"` - CKey chainKey `json:"chain_key"` - IsSet bool `json:"set"` -} - -// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair. -func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain { - return &senderChain{ - RKey: ratchet, - CKey: chainKey{ - Index: 0, - Key: key, - }, - IsSet: true, - } -} - -// advance advances the chain -func (s *senderChain) advance() { - s.CKey.advance() -} - -// ratchetKey returns the ratchet key pair. -func (s senderChain) ratchetKey() crypto.Curve25519KeyPair { - return s.RKey -} - -// chainKey returns the current chainKey. -func (s senderChain) chainKey() chainKey { - return s.CKey -} - -// UnpickleLibOlm unpickles the unencryted value and populates the sender chain -// accordingly. -func (r *senderChain) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - if err := r.RKey.UnpickleLibOlm(decoder); err != nil { - return err - } - return r.CKey.UnpickleLibOlm(decoder) -} - -// PickleLibOlm pickles the sender chain into the encoder. -func (r senderChain) PickleLibOlm(encoder *libolmpickle.Encoder) { - if r.IsSet { - encoder.WriteUInt32(1) // Length of the sender chain (1 if set) - r.RKey.PickleLibOlm(encoder) - r.CKey.PickleLibOlm(encoder) - } else { - encoder.WriteUInt32(0) - } -} - -// senderChain is a chain for receiving messages -type receiverChain struct { - RKey crypto.Curve25519PublicKey `json:"ratchet_key"` - CKey chainKey `json:"chain_key"` -} - -// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. -func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain { - return &receiverChain{ - RKey: ratchet, - CKey: chainKey{ - Index: 0, - Key: chain, - }, - } -} - -// advance advances the chain -func (s *receiverChain) advance() { - s.CKey.advance() -} - -// ratchetKey returns the ratchet public key. -func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey { - return s.RKey -} - -// chainKey returns the current chainKey. -func (s receiverChain) chainKey() chainKey { - return s.CKey -} - -// UnpickleLibOlm unpickles the unencryted value and populates the chain accordingly. -func (r *receiverChain) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { - if err := r.RKey.UnpickleLibOlm(decoder); err != nil { - return err - } - return r.CKey.UnpickleLibOlm(decoder) -} - -// PickleLibOlm pickles the receiver chain into the encoder. -func (r receiverChain) PickleLibOlm(encoder *libolmpickle.Encoder) { - r.RKey.PickleLibOlm(encoder) - r.CKey.PickleLibOlm(encoder) -} - -// messageKey wraps the index and the key of a message -type messageKey struct { - Index uint32 `json:"index"` - Key []byte `json:"key"` -} - -// UnpickleLibOlm unpickles the unencryted value and populates the message key -// accordingly. -func (m *messageKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { - if m.Key, err = decoder.ReadBytes(messageKeyLength); err != nil { - return - } - m.Index, err = decoder.ReadUInt32() - return -} - -// PickleLibOlm pickles the message key into the encoder. -func (m messageKey) PickleLibOlm(encoder *libolmpickle.Encoder) { - if len(m.Key) == messageKeyLength { - encoder.Write(m.Key) - } else { - encoder.WriteEmptyBytes(messageKeyLength) - } - encoder.WriteUInt32(m.Index) -} diff --git a/crypto/goolm/ratchet/olm_test.go b/crypto/goolm/ratchet/olm_test.go deleted file mode 100644 index 2bf7ea0a..00000000 --- a/crypto/goolm/ratchet/olm_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package ratchet_test - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/ratchet" -) - -var ( - sharedSecret = []byte("A secret") -) - -func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) { - ratchet.KdfInfo = struct { - Root []byte - Ratchet []byte - }{ - Root: []byte("Olm"), - Ratchet: []byte("OlmRatchet"), - } - aliceRatchet := ratchet.New() - bobRatchet := ratchet.New() - - aliceKey, err := crypto.Curve25519GenerateKey() - if err != nil { - return nil, nil, err - } - - aliceRatchet.InitializeAsAlice(sharedSecret, aliceKey) - bobRatchet.InitializeAsBob(sharedSecret, aliceKey.PublicKey) - return aliceRatchet, bobRatchet, nil -} - -func TestSendReceive(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - assert.NoError(t, err) - - plainText := []byte("Hello Bob") - - //Alice sends Bob a message - encryptedMessage, err := aliceRatchet.Encrypt(plainText) - assert.NoError(t, err) - - decrypted, err := bobRatchet.Decrypt(encryptedMessage) - assert.NoError(t, err) - assert.Equal(t, plainText, decrypted) - - //Bob sends Alice a message - plainText = []byte("Hello Alice") - encryptedMessage, err = bobRatchet.Encrypt(plainText) - assert.NoError(t, err) - decrypted, err = aliceRatchet.Decrypt(encryptedMessage) - assert.NoError(t, err) - assert.Equal(t, plainText, decrypted) -} - -func TestOutOfOrder(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - assert.NoError(t, err) - - plainText1 := []byte("First Message") - plainText2 := []byte("Second Messsage. A bit longer than the first.") - - /* Alice sends Bob two messages and they arrive out of order */ - message1Encrypted, err := aliceRatchet.Encrypt(plainText1) - assert.NoError(t, err) - message2Encrypted, err := aliceRatchet.Encrypt(plainText2) - assert.NoError(t, err) - - decrypted2, err := bobRatchet.Decrypt(message2Encrypted) - assert.NoError(t, err) - decrypted1, err := bobRatchet.Decrypt(message1Encrypted) - assert.NoError(t, err) - assert.Equal(t, plainText1, decrypted1) - assert.Equal(t, plainText2, decrypted2) -} - -func TestMoreMessages(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - assert.NoError(t, err) - plainText := []byte("These 15 bytes") - for i := 0; i < 8; i++ { - messageEncrypted, err := aliceRatchet.Encrypt(plainText) - assert.NoError(t, err) - - decrypted, err := bobRatchet.Decrypt(messageEncrypted) - assert.NoError(t, err) - assert.Equal(t, plainText, decrypted) - } - for i := 0; i < 8; i++ { - messageEncrypted, err := bobRatchet.Encrypt(plainText) - assert.NoError(t, err) - - decrypted, err := aliceRatchet.Decrypt(messageEncrypted) - assert.NoError(t, err) - assert.Equal(t, plainText, decrypted) - } - messageEncrypted, err := aliceRatchet.Encrypt(plainText) - assert.NoError(t, err) - decrypted, err := bobRatchet.Decrypt(messageEncrypted) - assert.NoError(t, err) - assert.Equal(t, plainText, decrypted) -} - -func TestJSONEncoding(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - assert.NoError(t, err) - marshaled, err := json.Marshal(aliceRatchet) - assert.NoError(t, err) - - newRatcher := ratchet.Ratchet{} - err = json.Unmarshal(marshaled, &newRatcher) - assert.NoError(t, err) - - plainText := []byte("These 15 bytes") - - messageEncrypted, err := newRatcher.Encrypt(plainText) - assert.NoError(t, err) - decrypted, err := bobRatchet.Decrypt(messageEncrypted) - assert.NoError(t, err) - assert.Equal(t, plainText, decrypted) -} diff --git a/crypto/goolm/ratchet/skipped_message.go b/crypto/goolm/ratchet/skipped_message.go deleted file mode 100644 index 2ffaee7b..00000000 --- a/crypto/goolm/ratchet/skipped_message.go +++ /dev/null @@ -1,27 +0,0 @@ -package ratchet - -import ( - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" -) - -// skippedMessageKey stores a skipped message key -type skippedMessageKey struct { - RKey crypto.Curve25519PublicKey `json:"ratchet_key"` - MKey messageKey `json:"message_key"` -} - -// UnpickleLibOlm unpickles the unencryted value and populates the skipped -// message keys accordingly. -func (r *skippedMessageKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { - if err = r.RKey.UnpickleLibOlm(decoder); err != nil { - return - } - return r.MKey.UnpickleLibOlm(decoder) -} - -// PickleLibOlm pickles the skipped message key into the encoder. -func (r skippedMessageKey) PickleLibOlm(encoder *libolmpickle.Encoder) { - r.RKey.PickleLibOlm(encoder) - r.MKey.PickleLibOlm(encoder) -} diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go deleted file mode 100644 index 800f567f..00000000 --- a/crypto/goolm/register.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package goolm - -import ( - "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/olm" -) - -func Register() { - olm.Driver = "goolm" - - olm.GetVersion = func() (major, minor, patch uint8) { - return 3, 2, 15 - } - olm.SetPickleKeyImpl = func(key []byte) { - panic("gob and json encoding is deprecated and not supported with goolm") - } - - account.Register() - pk.Register() - session.Register() -} diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 7ccbd26d..165f7f16 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -2,14 +2,16 @@ package session import ( "encoding/base64" + "errors" "fmt" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/id" ) @@ -26,14 +28,10 @@ type MegolmInboundSession struct { SigningKeyVerified bool `json:"signing_key_verified"` //not used for now } -// Ensure that MegolmInboundSession implements the [olm.InboundGroupSession] -// interface. -var _ olm.InboundGroupSession = (*MegolmInboundSession)(nil) - // NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message. func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) { var err error - input, err = goolmbase64.Decode(input) + input, err = goolm.Base64Decode(input) if err != nil { return nil, err } @@ -57,7 +55,7 @@ func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) { // NewMegolmInboundSessionFromExport creates a new MegolmInboundSession from a base64 encoded session export message. func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, error) { var err error - input, err = goolmbase64.Decode(input) + input, err = goolm.Base64Decode(input) if err != nil { return nil, err } @@ -80,7 +78,7 @@ func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, err // MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", goolm.ErrEmptyInput) } a := &MegolmInboundSession{} err := a.Unpickle(pickled, key) @@ -91,7 +89,7 @@ func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession } // getRatchet tries to find the correct ratchet for a messageIndex. -func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) { +func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) { // pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) { o.Ratchet.AdvanceTo(messageIndex) @@ -99,7 +97,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex) + return nil, fmt.Errorf("decrypt: %w", goolm.ErrRatchetNotAvailable) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -109,14 +107,11 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } // Decrypt decrypts a base64 encoded group message. -func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) { - if len(ciphertext) == 0 { - return nil, 0, olm.ErrEmptyInput - } +func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) { if o.SigningKey == nil { - return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) } - decoded, err := goolmbase64.Decode(ciphertext) + decoded, err := goolm.Base64Decode(ciphertext) if err != nil { return nil, 0, err } @@ -126,16 +121,16 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion) + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { - return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) } // verify signature verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded) if !verifiedSignature { - return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadSignature) + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadSignature) } targetRatch, err := o.getRatchet(msg.MessageIndex) @@ -148,33 +143,27 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) return nil, 0, err } o.SigningKeyVerified = true - return decrypted, uint(msg.MessageIndex), nil + return decrypted, msg.MessageIndex, nil } -// ID returns the base64 endoded signing key -func (o *MegolmInboundSession) ID() id.SessionID { +// SessionID returns the base64 endoded signing key +func (o MegolmInboundSession) SessionID() id.SessionID { return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey)) } // PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (o *MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return libolmpickle.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) +func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error { - return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) + return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) } -// Export returns the base64-encoded ratchet key for this session, at the given -// index, in a format which can be used by -// InboundGroupSession.InboundGroupSessionImport(). Encrypts the -// InboundGroupSession using the supplied key. Returns error on failure. -// if we do not have a session key corresponding to the given index (ie, it was -// sent before the session key was shared with us) the error will be -// returned. -func (o *MegolmInboundSession) Export(messageIndex uint32) ([]byte, error) { +// SessionExportMessage creates an base64 encoded export of the session. +func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) { ratchet, err := o.getRatchet(messageIndex) if err != nil { return nil, err @@ -185,75 +174,103 @@ func (o *MegolmInboundSession) Export(messageIndex uint32) ([]byte, error) { // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return olm.ErrNoKeyProvided - } else if len(pickled) == 0 { - return olm.ErrEmptyInput - } - decrypted, err := libolmpickle.Unpickle(key, pickled) + decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err } - return o.UnpickleLibOlm(decrypted) + _, err = o.UnpickleLibOlm(decrypted) + return err } -// UnpickleLibOlm unpickles the unencryted value and populates the [Session] -// accordingly. -func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { - decoder := libolmpickle.NewDecoder(value) - pickledVersion, err := decoder.ReadUInt32() +// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. +func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) if err != nil { - return err + return 0, err } - if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + switch pickledVersion { + case megolmInboundSessionPickleVersionLibOlm, 1: + default: + return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) } - - if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { - return err - } else if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { - return err - } else if err = o.SigningKey.UnpickleLibOlm(decoder); err != nil { - return err + readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err } - + curPos += readBytes + readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes if pickledVersion == 1 { // pickle v1 had no signing_key_verified field (all keyshares were verified at import time) o.SigningKeyVerified = true } else { - o.SigningKeyVerified, err = decoder.ReadBool() - return err + o.SigningKeyVerified, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes } - return nil + return curPos, nil } // Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm(). -func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided +func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, o.PickleLen()) + written, err := o.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err } - return libolmpickle.Pickle(key, o.PickleLibOlm()) + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil } -// PickleLibOlm pickles the session returning the raw bytes. -func (o *MegolmInboundSession) PickleLibOlm() []byte { - encoder := libolmpickle.NewEncoder() - encoder.WriteUInt32(megolmInboundSessionPickleVersionLibOlm) - o.InitialRatchet.PickleLibOlm(encoder) - o.Ratchet.PickleLibOlm(encoder) - o.SigningKey.PickleLibOlm(encoder) - encoder.WriteBool(o.SigningKeyVerified) - return encoder.Bytes() +// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { + if len(target) < o.PickleLen() { + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort) + } + written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target) + writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) + } + written += writtenInitRatchet + writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) + } + written += writtenRatchet + writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) + } + written += writtenPubKey + written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:]) + return written, nil } -// FirstKnownIndex returns the first message index we know how to decrypt. -func (s *MegolmInboundSession) FirstKnownIndex() uint32 { - return s.InitialRatchet.Counter -} - -// IsVerified check if the session has been verified as a valid session. (A -// session is verified either because the original session share was signed, or -// because we have subsequently successfully decrypted a message.) -func (s *MegolmInboundSession) IsVerified() bool { - return s.SigningKeyVerified +// PickleLen returns the number of bytes the pickled session will have. +func (o MegolmInboundSession) PickleLen() int { + length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm) + length += o.InitialRatchet.PickleLen() + length += o.Ratchet.PickleLen() + length += o.SigningKey.PickleLen() + length += libolmpickle.PickleBoolLen(o.SigningKeyVerified) + return length } diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 7f923534..e594258d 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -3,16 +3,17 @@ package session import ( "crypto/rand" "encoding/base64" + "errors" "fmt" - "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -26,13 +27,11 @@ type MegolmOutboundSession struct { SigningKey crypto.Ed25519KeyPair `json:"signing_key"` } -var _ olm.OutboundGroupSession = (*MegolmOutboundSession)(nil) - // NewMegolmOutboundSession creates a new MegolmOutboundSession. func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { o := &MegolmOutboundSession{} var err error - o.SigningKey, err = crypto.Ed25519GenerateKey() + o.SigningKey, err = crypto.Ed25519GenerateKey(nil) if err != nil { return nil, err } @@ -52,94 +51,121 @@ func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { // MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", goolm.ErrEmptyInput) } a := &MegolmOutboundSession{} err := a.Unpickle(pickled, key) - return a, err + if err != nil { + return nil, err + } + return a, nil } // Encrypt encrypts the plaintext as a base64 encoded group message. func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { - if len(plaintext) == 0 { - return nil, olm.ErrEmptyInput + encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey) + if err != nil { + return nil, err } - encrypted, err := o.Ratchet.Encrypt(plaintext, o.SigningKey) - return goolmbase64.Encode(encrypted), err + return goolm.Base64Encode(encrypted), nil } // SessionID returns the base64 endoded public signing key -func (o *MegolmOutboundSession) ID() id.SessionID { +func (o MegolmOutboundSession) SessionID() id.SessionID { return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey)) } // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (o *MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return libolmpickle.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) +func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error { - return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) + return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return olm.ErrNoKeyProvided - } - decrypted, err := libolmpickle.Unpickle(key, pickled) + decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err } - return o.UnpickleLibOlm(decrypted) + _, err = o.UnpickleLibOlm(decrypted) + return err } -// UnpickleLibOlm unpickles the unencryted value and populates the -// [MegolmOutboundSession] accordingly. -func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { - decoder := libolmpickle.NewDecoder(buf) - pickledVersion, err := decoder.ReadUInt32() +// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. +func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) if err != nil { - return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err) - } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + return 0, err } - if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { - return err + switch pickledVersion { + case megolmOutboundSessionPickleVersionLibOlm: + default: + return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) } - return o.SigningKey.UnpickleLibOlm(decoder) + readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil } // Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm(). -func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided +func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, o.PickleLen()) + written, err := o.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err } - return libolmpickle.Pickle(key, o.PickleLibOlm()) + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil } -// PickleLibOlm pickles the session returning the raw bytes. -func (o *MegolmOutboundSession) PickleLibOlm() []byte { - encoder := libolmpickle.NewEncoder() - encoder.WriteUInt32(megolmOutboundSessionPickleVersionLibOlm) - o.Ratchet.PickleLibOlm(encoder) - o.SigningKey.PickleLibOlm(encoder) - return encoder.Bytes() +// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { + if len(target) < o.PickleLen() { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) + } + written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target) + writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) + } + written += writtenRatchet + writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) + } + written += writtenPubKey + return written, nil } -func (o *MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { +// PickleLen returns the number of bytes the pickled session will have. +func (o MegolmOutboundSession) PickleLen() int { + length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm) + length += o.Ratchet.PickleLen() + length += o.SigningKey.PickleLen() + return length +} + +func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { return o.Ratchet.SessionSharingMessage(o.SigningKey) } - -// MessageIndex returns the message index for this session. Each message is -// sent with an increasing index; this returns the index for the next message. -func (s *MegolmOutboundSession) MessageIndex() uint { - return uint(s.Ratchet.Counter) -} - -// Key returns the base64-encoded current ratchet key for this session. -func (s *MegolmOutboundSession) Key() string { - return string(exerrors.Must(s.SessionSharingMessage())) -} diff --git a/crypto/goolm/session/megolm_session_test.go b/crypto/goolm/session/megolm_session_test.go index 72d8857b..9b3f56b5 100644 --- a/crypto/goolm/session/megolm_session_test.go +++ b/crypto/goolm/session/megolm_session_test.go @@ -1,57 +1,92 @@ package session_test import ( + "bytes" "crypto/rand" - "encoding/base64" + "errors" "testing" - "github.com/stretchr/testify/assert" - + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/olm" ) func TestOutboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") sess, err := session.NewMegolmOutboundSession() - assert.NoError(t, err) - kp, err := crypto.Ed25519GenerateKey() - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } + kp, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } sess.SigningKey = kp pickled, err := sess.PickleAsJSON(pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } newSession := session.MegolmOutboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) - assert.NoError(t, err) - assert.Equal(t, sess.ID(), newSession.ID()) - assert.Equal(t, sess.SigningKey, newSession.SigningKey) - assert.Equal(t, sess.Ratchet, newSession.Ratchet) + if err != nil { + t.Fatal(err) + } + if sess.SessionID() != newSession.SessionID() { + t.Fatal("session ids not equal") + } + if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) { + t.Fatal("private keys not equal") + } + if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { + t.Fatal("ratchet data not equal") + } + if sess.Ratchet.Counter != newSession.Ratchet.Counter { + t.Fatal("ratchet counter not equal") + } } func TestInboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") sess := session.MegolmInboundSession{} - kp, err := crypto.Ed25519GenerateKey() - assert.NoError(t, err) + kp, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } sess.SigningKey = kp.PublicKey var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte _, err = rand.Read(randomData[:]) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } ratchet, err := megolm.New(0, randomData) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } sess.Ratchet = *ratchet pickled, err := sess.PickleAsJSON(pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } newSession := session.MegolmInboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) - assert.NoError(t, err) - assert.Equal(t, sess.ID(), newSession.ID()) - assert.Equal(t, sess.SigningKey, newSession.SigningKey) - assert.Equal(t, sess.Ratchet, newSession.Ratchet) + if err != nil { + t.Fatal(err) + } + if sess.SessionID() != newSession.SessionID() { + t.Fatal("sess ids not equal") + } + if !bytes.Equal(sess.SigningKey, newSession.SigningKey) { + t.Fatal("private keys not equal") + } + if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { + t.Fatal("ratchet data not equal") + } + if sess.Ratchet.Counter != newSession.Ratchet.Counter { + t.Fatal("ratchet counter not equal") + } } func TestGroupSendReceive(t *testing.T) { @@ -65,27 +100,46 @@ func TestGroupSendReceive(t *testing.T) { ) outboundSession, err := session.NewMegolmOutboundSession() - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } copy(outboundSession.Ratchet.Data[:], randomData) - assert.EqualValues(t, 0, outboundSession.Ratchet.Counter) - + if outboundSession.Ratchet.Counter != 0 { + t.Fatal("ratchet counter is not correkt") + } sessionSharing, err := outboundSession.SessionSharingMessage() - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } plainText := []byte("Message") ciphertext, err := outboundSession.Encrypt(plainText) - assert.NoError(t, err) - assert.EqualValues(t, 1, outboundSession.Ratchet.Counter) + if err != nil { + t.Fatal(err) + } + if outboundSession.Ratchet.Counter != 1 { + t.Fatal("ratchet counter is not correkt") + } //build inbound session inboundSession, err := session.NewMegolmInboundSession(sessionSharing) - assert.NoError(t, err) - assert.True(t, inboundSession.SigningKeyVerified) - assert.Equal(t, outboundSession.ID(), inboundSession.ID()) + if err != nil { + t.Fatal(err) + } + if !inboundSession.SigningKeyVerified { + t.Fatal("key not verified") + } + if inboundSession.SessionID() != outboundSession.SessionID() { + t.Fatal("session ids not equal") + } //decode message decoded, _, err := inboundSession.Decrypt(ciphertext) - assert.NoError(t, err) - assert.Equal(t, plainText, decoded) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decoded) { + t.Fatal("messages not equal") + } } func TestGroupSessionExportImport(t *testing.T) { @@ -104,26 +158,45 @@ func TestGroupSessionExportImport(t *testing.T) { //init inbound inboundSession, err := session.NewMegolmInboundSession(sessionKey) - assert.NoError(t, err) - assert.True(t, inboundSession.SigningKeyVerified) + if err != nil { + t.Fatal(err) + } + if !inboundSession.SigningKeyVerified { + t.Fatal("signing key not verified") + } decrypted, _, err := inboundSession.Decrypt(message) - assert.NoError(t, err) - assert.Equal(t, plaintext, decrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("message is not correct") + } //Export the keys - exported, err := inboundSession.Export(0) - assert.NoError(t, err) + exported, err := inboundSession.SessionExportMessage(0) + if err != nil { + t.Fatal(err) + } secondInboundSession, err := session.NewMegolmInboundSessionFromExport(exported) - assert.NoError(t, err) - assert.False(t, secondInboundSession.SigningKeyVerified) - + if err != nil { + t.Fatal(err) + } + if secondInboundSession.SigningKeyVerified { + t.Fatal("signing key is verified") + } //decrypt with new session decrypted, _, err = secondInboundSession.Decrypt(message) - assert.NoError(t, err) - assert.Equal(t, plaintext, decrypted) - assert.True(t, secondInboundSession.SigningKeyVerified) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("message is not correct") + } + if !secondInboundSession.SigningKeyVerified { + t.Fatal("signing key not verified") + } } func TestBadSignatureGroupMessage(t *testing.T) { @@ -142,43 +215,70 @@ func TestBadSignatureGroupMessage(t *testing.T) { //init inbound inboundSession, err := session.NewMegolmInboundSession(sessionKey) - assert.NoError(t, err) - assert.True(t, inboundSession.SigningKeyVerified) + if err != nil { + t.Fatal(err) + } + if !inboundSession.SigningKeyVerified { + t.Fatal("signing key not verified") + } decrypted, _, err := inboundSession.Decrypt(message) - assert.NoError(t, err) - assert.Equal(t, plaintext, decrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("message is not correct") + } //Now twiddle the signature copy(message[len(message)-1:], []byte("E")) _, _, err = inboundSession.Decrypt(message) - assert.ErrorIs(t, err, olm.ErrBadSignature) + if err == nil { + t.Fatal("Signature was changed but did not cause an error") + } + if !errors.Is(err, goolm.ErrBadSignature) { + t.Fatalf("wrong error %s", err.Error()) + } } func TestOutbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItUO3TiOp5I+6PnQka6n8eHTyIEh3tCetilD+BKnHvtakE0eHHvG6pjEsMNN/vs7lkB5rV6XkoUKHLTE1dAfFunYEeHEZuKQpbG385dBwaMJXt4JrC0hU5jnv6jWNqAA0Ud9GxRDvkp04") pickleKey := []byte("secret_key") sess, err := session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } newPickled, err := sess.Pickle(pickleKey) - assert.NoError(t, err) - assert.Equal(t, pickledDataFromLibOlm, newPickled) - + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pickledDataFromLibOlm, newPickled) { + t.Fatal("pickled version does not equal libolm version") + } pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - assert.ErrorIs(t, err, olm.ErrBadMAC) + if err == nil { + t.Fatal("should have gotten an error") + } } func TestInbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("1/IPCdtUoQxMba5XT7sjjUW0Hrs7no9duGFnhsEmxzFX2H3qtRc4eaFBRZYXxOBRTGZ6eMgy3IiSrgAQ1gUlSZf5Q4AVKeBkhvN4LZ6hdhQFv91mM+C2C55/4B9/gDjJEbDGiRgLoMqbWPDV+y0F4h0KaR1V1PiTCC7zCi4WdxJQ098nJLgDL4VSsDbnaLcSMO60FOYgRN4KsLaKUGkXiiUBWp4boFMCiuTTOiyH8XlH0e9uWc0vMLyGNUcO8kCbpAnx3v1JTIVan3WGsnGv4K8Qu4M8GAkZewpexrsb2BSNNeLclOV9/cR203Y5KlzXcpiWNXSs8XoB3TLEtHYMnjuakMQfyrcXKIQntg4xPD/+wvfqkcMg9i7pcplQh7X2OK5ylrMZQrZkJ1fAYBGbBz1tykWOjfrZ") pickleKey := []byte("secret_key") sess, err := session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } newPickled, err := sess.Pickle(pickleKey) - assert.NoError(t, err) - assert.Equal(t, pickledDataFromLibOlm, newPickled) - + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pickledDataFromLibOlm, newPickled) { + t.Fatal("pickled version does not equal libolm version") + } pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - assert.ErrorIs(t, err, base64.CorruptInputError(416)) + if err == nil { + t.Fatal("should have gotten an error") + } } diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index a1cb8d66..6655e0a5 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -2,17 +2,18 @@ package session import ( "bytes" - "crypto/sha256" "encoding/base64" + "errors" "fmt" - "strings" + "io" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/ratchet" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/olm" + "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/id" ) @@ -31,11 +32,9 @@ type OlmSession struct { AliceIdentityKey crypto.Curve25519PublicKey `json:"alice_id_key"` AliceBaseKey crypto.Curve25519PublicKey `json:"alice_base_key"` BobOneTimeKey crypto.Curve25519PublicKey `json:"bob_one_time_key"` - Ratchet ratchet.Ratchet `json:"ratchet"` + Ratchet olm.Ratchet `json:"ratchet"` } -var _ olm.Session = (*OlmSession)(nil) - // SearchOTKFunc is used to retrieve a crypto.OneTimeKey from a public key. type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey @@ -43,25 +42,33 @@ type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey // the Session using the supplied key. func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) } a := &OlmSession{} - return a, a.UnpickleAsJSON(pickled, key) + err := a.UnpickleAsJSON(pickled, key) + if err != nil { + return nil, err + } + return a, nil } // OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key. func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) } a := &OlmSession{} - return a, a.Unpickle(pickled, key) + err := a.Unpickle(pickled, key) + if err != nil { + return nil, err + } + return a, nil } // NewOlmSession creates a new Session. func NewOlmSession() *OlmSession { s := &OlmSession{} - s.Ratchet = *ratchet.New() + s.Ratchet = *olm.New() return s } @@ -70,12 +77,12 @@ func NewOlmSession() *OlmSession { func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) { s := NewOlmSession() //generate E_A - baseKey, err := crypto.Curve25519GenerateKey() + baseKey, err := crypto.Curve25519GenerateKey(nil) if err != nil { return nil, err } //generate T_0 - ratchetKey, err := crypto.Curve25519GenerateKey() + ratchetKey, err := crypto.Curve25519GenerateKey(nil) if err != nil { return nil, err } @@ -110,7 +117,7 @@ func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKe // NewInboundOlmSession creates a new inbound session from receiving the first message. func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, receivedOTKMsg []byte, searchBobOTK SearchOTKFunc, identityKeyBob crypto.Curve25519KeyPair) (*OlmSession, error) { - decodedOTKMsg, err := goolmbase64.Decode(receivedOTKMsg) + decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) if err != nil { return nil, err } @@ -123,7 +130,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received return nil, fmt.Errorf("OneTimeKeyMessage decode: %w", err) } if !oneTimeMsg.CheckFields(identityKeyAlice) { - return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", goolm.ErrBadMessageFormat) } //Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked @@ -131,7 +138,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 { //if both are set, compare them if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) { - return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", olm.ErrBadMessageKeyID) + return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", goolm.ErrBadMessageKeyID) } } if identityKeyAlice == nil { @@ -141,7 +148,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey) if oneTimeKeyBob == nil { - return nil, fmt.Errorf("ourOneTimeKey: %w", olm.ErrBadMessageKeyID) + return nil, fmt.Errorf("ourOneTimeKey: %w", goolm.ErrBadMessageKeyID) } //Calculate shared secret via Triple Diffie-Hellman @@ -168,11 +175,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, fmt.Errorf("message decode: %w", err) + return nil, fmt.Errorf("Message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -187,64 +194,40 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) { - return libolmpickle.PickleAsJSON(a, olmSessionPickleVersionJSON, key) + return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { - return libolmpickle.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) + return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) } // ID returns an identifier for this Session. Will be the same for both ends of the conversation. // Generated by hashing the public keys used to create the session. -func (s *OlmSession) ID() id.SessionID { - message := make([]byte, 3*crypto.Curve25519PrivateKeyLength) +func (s OlmSession) ID() id.SessionID { + message := make([]byte, 3*crypto.Curve25519KeyLength) copy(message, s.AliceIdentityKey) - copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) - copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) - hash := sha256.Sum256(message) - res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:])) + copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey) + copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey) + hash := crypto.SHA256(message) + res := id.SessionID(goolm.Base64Encode(hash)) return res } // HasReceivedMessage returns true if this session has received any message. -func (s *OlmSession) HasReceivedMessage() bool { +func (s OlmSession) HasReceivedMessage() bool { return s.ReceivedMessage } -// MatchesInboundSession checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. -func (s *OlmSession) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { - return s.matchesInboundSession(nil, []byte(oneTimeKeyMsg)) -} - -// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. -func (s *OlmSession) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { - var theirKey *id.Curve25519 - if theirIdentityKey != "" { - theirs := id.Curve25519(theirIdentityKey) - theirKey = &theirs - } - - return s.matchesInboundSession(theirKey, []byte(oneTimeKeyMsg)) -} - -// matchesInboundSession checks if the oneTimeKeyMsg message is set for this -// inbound Session. This can happen if multiple messages are sent to this -// Account before this Account sends a message in reply. Returns true if the -// session matches. Returns false if the session does not match. -func (s *OlmSession) matchesInboundSession(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { +// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. +func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { if len(receivedOTKMsg) == 0 { - return false, fmt.Errorf("inbound match: %w", olm.ErrEmptyInput) + return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput) } - decodedOTKMsg, err := goolmbase64.Decode(receivedOTKMsg) + decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) if err != nil { return false, err } @@ -283,20 +266,20 @@ func (s *OlmSession) matchesInboundSession(theirIdentityKeyEncoded *id.Curve2551 // EncryptMsgType returns the type of the next message that Encrypt will // return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg. // Returns MsgTypeMsg if the message will be a normal message. -func (s *OlmSession) EncryptMsgType() id.OlmMsgType { +func (s OlmSession) EncryptMsgType() id.OlmMsgType { if s.ReceivedMessage { return id.OlmMsgTypeMsg } return id.OlmMsgTypePreKey } -// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. -func (s *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { +// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations. +func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, fmt.Errorf("encrypt: %w", olm.ErrEmptyInput) + return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput) } messageType := s.EncryptMsgType() - encrypted, err := s.Ratchet.Encrypt(plaintext) + encrypted, err := s.Ratchet.Encrypt(plaintext, reader) if err != nil { return 0, nil, err } @@ -317,15 +300,15 @@ func (s *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { result = messageBody } - return messageType, goolmbase64.Encode(result), nil + return messageType, goolm.Base64Encode(result), nil } // Decrypt decrypts a base64 encoded message using the Session. -func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, error) { +func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) { if len(crypttext) == 0 { - return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) + return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput) } - decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext) + decodedCrypttext, err := goolm.Base64Decode(crypttext) if err != nil { return nil, err } @@ -350,80 +333,144 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *OlmSession) Unpickle(pickled, key []byte) error { - if len(pickled) == 0 { - return olm.ErrEmptyInput - } - decrypted, err := libolmpickle.Unpickle(key, pickled) + decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err } - return o.UnpickleLibOlm(decrypted) + _, err = o.UnpickleLibOlm(decrypted) + return err } -// UnpickleLibOlm unpickles the unencryted value and populates the [OlmSession] -// accordingly. -func (o *OlmSession) UnpickleLibOlm(buf []byte) error { - decoder := libolmpickle.NewDecoder(buf) - pickledVersion, err := decoder.ReadUInt32() +// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. +func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) if err != nil { - return fmt.Errorf("unpickle olmSession: failed to read version: %w", err) + return 0, err } - - var includesChainIndex bool + includesChainIndex := true switch pickledVersion { case olmSessionPickleVersionLibOlm: includesChainIndex = false case uint32(0x80000001): includesChainIndex = true default: - return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) } - - if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { - return err - } else if err = o.AliceIdentityKey.UnpickleLibOlm(decoder); err != nil { - return err - } else if err = o.AliceBaseKey.UnpickleLibOlm(decoder); err != nil { - return err - } else if err = o.BobOneTimeKey.UnpickleLibOlm(decoder); err != nil { - return err + var readBytes int + o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) + if err != nil { + return 0, err } - return o.Ratchet.UnpickleLibOlm(decoder, includesChainIndex) + curPos += readBytes + readBytes, err = o.AliceIdentityKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.AliceBaseKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.BobOneTimeKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:], includesChainIndex) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil } -// Pickle returns a base64 encoded and with key encrypted pickled olmSession -// using PickleLibOlm(). -func (s *OlmSession) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided +// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm(). +func (o OlmSession) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, o.PickleLen()) + written, err := o.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err } - return libolmpickle.Pickle(key, s.PickleLibOlm()) + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil } -// PickleLibOlm pickles the session and returns the raw bytes. -func (o *OlmSession) PickleLibOlm() []byte { - encoder := libolmpickle.NewEncoder() - encoder.WriteUInt32(olmSessionPickleVersionLibOlm) - encoder.WriteBool(o.ReceivedMessage) - o.AliceIdentityKey.PickleLibOlm(encoder) - o.AliceBaseKey.PickleLibOlm(encoder) - o.BobOneTimeKey.PickleLibOlm(encoder) - o.Ratchet.PickleLibOlm(encoder) - return encoder.Bytes() +// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (o OlmSession) PickleLibOlm(target []byte) (int, error) { + if len(target) < o.PickleLen() { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) + } + written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target) + written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:]) + writtenRatchet, err := o.AliceIdentityKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) + } + written += writtenRatchet + writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) + } + written += writtenRatchet + writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) + } + written += writtenRatchet + writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) + } + written += writtenRatchet + return written, nil +} + +// PickleLen returns the actual number of bytes the pickled session will have. +func (o OlmSession) PickleLen() int { + length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) + length += libolmpickle.PickleBoolLen(o.ReceivedMessage) + length += o.AliceIdentityKey.PickleLen() + length += o.AliceBaseKey.PickleLen() + length += o.BobOneTimeKey.PickleLen() + length += o.Ratchet.PickleLen() + return length +} + +// PickleLenMin returns the minimum number of bytes the pickled session must have. +func (o OlmSession) PickleLenMin() int { + length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) + length += libolmpickle.PickleBoolLen(o.ReceivedMessage) + length += o.AliceIdentityKey.PickleLen() + length += o.AliceBaseKey.PickleLen() + length += o.BobOneTimeKey.PickleLen() + length += o.Ratchet.PickleLenMin() + return length } // Describe returns a string describing the current state of the session for debugging. -func (o *OlmSession) Describe() string { - var builder strings.Builder - builder.WriteString("sender chain index: ") - builder.WriteString(fmt.Sprint(o.Ratchet.SenderChains.CKey.Index)) - builder.WriteString(" receiver chain indices:") +func (o OlmSession) Describe() string { + var res string + if o.Ratchet.SenderChains.IsSet { + res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index) + } else { + res += "sender chain index: " + } + res += "receiver chain indicies:" for _, curChain := range o.Ratchet.ReceiverChains { - builder.WriteString(fmt.Sprintf(" %d", curChain.CKey.Index)) + res += fmt.Sprintf(" %d", curChain.CKey.Index) } - builder.WriteString(" skipped message keys:") + res += " skipped message keys:" for _, curSkip := range o.Ratchet.SkippedMessageKeys { - builder.WriteString(fmt.Sprintf(" %d", curSkip.MKey.Index)) + res += fmt.Sprintf(" %d", curSkip.MKey.Index) } - return builder.String() + return res } diff --git a/crypto/goolm/session/olm_session_test.go b/crypto/goolm/session/olm_session_test.go index f87c2e7e..11b13c32 100644 --- a/crypto/goolm/session/olm_session_test.go +++ b/crypto/goolm/session/olm_session_test.go @@ -1,32 +1,44 @@ package session_test import ( + "bytes" "encoding/base64" + "errors" "testing" - "github.com/stretchr/testify/assert" - + "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) func TestOlmSession(t *testing.T) { pickleKey := []byte("secretKey") - aliceKeyPair, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) - bobKeyPair, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) - bobOneTimeKey, err := crypto.Curve25519GenerateKey() - assert.NoError(t, err) + aliceKeyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + bobKeyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + bobOneTimeKey, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } aliceSession, err := session.NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } //create a message so that there are more keys to marshal plaintext := []byte("Test message from Alice to Bob") - msgType, message, err := aliceSession.Encrypt(plaintext) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType) + msgType, message, err := aliceSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("Wrong message type") + } searchFunc := func(target crypto.Curve25519PublicKey) *crypto.OneTimeKey { if target.Equal(bobOneTimeKey.PublicKey) { @@ -40,58 +52,92 @@ func TestOlmSession(t *testing.T) { } //bob receives message bobSession, err := session.NewInboundOlmSession(nil, message, searchFunc, bobKeyPair) - assert.NoError(t, err) - decryptedMsg, err := bobSession.Decrypt(string(message), msgType) - assert.NoError(t, err) - assert.Equal(t, plaintext, decryptedMsg) + if err != nil { + t.Fatal(err) + } + decryptedMsg, err := bobSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decryptedMsg) { + t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) + } // Alice pickles session pickled, err := aliceSession.PickleAsJSON(pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } //bob sends a message plaintext = []byte("A message from Bob to Alice") - msgType, message, err = bobSession.Encrypt(plaintext) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypeMsg, msgType) + msgType, message, err = bobSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypeMsg { + t.Fatal("Wrong message type") + } //Alice unpickles session newAliceSession, err := session.OlmSessionFromJSONPickled(pickled, pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } //Alice receives message - decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType) - assert.NoError(t, err) - assert.Equal(t, plaintext, decryptedMsg) + decryptedMsg, err = newAliceSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decryptedMsg) { + t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) + } //Alice receives message again - _, err = newAliceSession.Decrypt(string(message), msgType) - assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound) + _, err = newAliceSession.Decrypt(message, msgType) + if err == nil { + t.Fatal("should have gotten an error") + } //Alice sends another message plaintext = []byte("A second message to Bob") - msgType, message, err = newAliceSession.Encrypt(plaintext) - assert.NoError(t, err) - assert.Equal(t, id.OlmMsgTypeMsg, msgType) - + msgType, message, err = newAliceSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypeMsg { + t.Fatal("Wrong message type") + } //bob receives message - decryptedMsg, err = bobSession.Decrypt(string(message), msgType) - assert.NoError(t, err) - assert.Equal(t, plaintext, decryptedMsg) + decryptedMsg, err = bobSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decryptedMsg) { + t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) + } } func TestSessionPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") pickleKey := []byte("secret_key") sess, err := session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } newPickled, err := sess.Pickle(pickleKey) - assert.NoError(t, err) - assert.Equal(t, pickledDataFromLibOlm, newPickled) - + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pickledDataFromLibOlm, newPickled) { + t.Fatal("pickled version does not equal libolm version") + } pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) - assert.ErrorIs(t, err, base64.CorruptInputError(224)) + if err == nil { + t.Fatal("should have gotten an error") + } } func TestDecrypts(t *testing.T) { @@ -102,7 +148,7 @@ func TestDecrypts(t *testing.T) { {0xe9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xc1}, } expectedErr := []error{ - olm.ErrInputToSmall, + goolm.ErrInputToSmall, // Why are these being tested 🤔 base64.CorruptInputError(0), base64.CorruptInputError(0), @@ -115,9 +161,17 @@ func TestDecrypts(t *testing.T) { "dGvPXeH8qLeNZA") pickleKey := []byte("") sess, err := session.OlmSessionFromPickled(sessionPickled, pickleKey) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } for curIndex, curMessage := range messages { - _, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey) - assert.ErrorIs(t, err, expectedErr[curIndex]) + _, err := sess.Decrypt(curMessage, id.OlmMsgTypePreKey) + if err != nil { + if !errors.Is(err, expectedErr[curIndex]) { + t.Fatal(err) + } + } else { + t.Fatal("error expected") + } } } diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go deleted file mode 100644 index b95a44ac..00000000 --- a/crypto/goolm/session/register.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package session - -import ( - "maunium.net/go/mautrix/crypto/olm" -) - -func Register() { - // Inbound Session - olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { - if len(pickled) == 0 { - return nil, olm.ErrEmptyInput - } - if len(key) == 0 { - key = []byte(" ") - } - return MegolmInboundSessionFromPickled(pickled, key) - } - olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput - } - return NewMegolmInboundSession(sessionKey) - } - olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput - } - return NewMegolmInboundSessionFromExport(sessionKey) - } - olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { - return &MegolmInboundSession{} - } - - // Outbound Session - olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, olm.ErrEmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - return MegolmOutboundSessionFromPickled(pickled, key) - } - olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewMegolmOutboundSession() } - olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return &MegolmOutboundSession{} } - - // Olm Session - olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { - return OlmSessionFromPickled(pickled, key) - } - olm.InitNewBlankSession = func() olm.Session { - return NewOlmSession() - } -} diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/utilities/pickle.go similarity index 72% rename from crypto/goolm/libolmpickle/picklejson.go rename to crypto/goolm/utilities/pickle.go index f765391f..993366c8 100644 --- a/crypto/goolm/libolmpickle/picklejson.go +++ b/crypto/goolm/utilities/pickle.go @@ -1,17 +1,17 @@ -package libolmpickle +package utilities import ( - "crypto/aes" "encoding/json" "fmt" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" ) // PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format. func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { if len(key) == 0 { - return nil, fmt.Errorf("pickle: %w", olm.ErrNoKeyProvided) + return nil, fmt.Errorf("pickle: %w", goolm.ErrNoKeyProvided) } marshaled, err := json.Marshal(object) if err != nil { @@ -21,12 +21,12 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { toEncrypt := make([]byte, len(marshaled)) copy(toEncrypt, marshaled) //pad marshaled to get block size - if len(marshaled)%aes.BlockSize != 0 { - padding := aes.BlockSize - len(marshaled)%aes.BlockSize + if len(marshaled)%cipher.PickleBlockSize() != 0 { + padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize() toEncrypt = make([]byte, len(marshaled)+padding) copy(toEncrypt, marshaled) } - encrypted, err := Pickle(key, toEncrypt) + encrypted, err := cipher.Pickle(key, toEncrypt) if err != nil { return nil, fmt.Errorf("pickle encrypt: %w", err) } @@ -36,9 +36,9 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { // UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { if len(key) == 0 { - return fmt.Errorf("unpickle: %w", olm.ErrNoKeyProvided) + return fmt.Errorf("unpickle: %w", goolm.ErrNoKeyProvided) } - decrypted, err := Unpickle(key, pickled) + decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return fmt.Errorf("unpickle decrypt: %w", err) } @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion) + return fmt.Errorf("unpickle: %w", goolm.ErrWrongPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 7b3c30db..3e65f4c1 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -2,8 +2,6 @@ package crypto import ( "context" - "encoding/base64" - "errors" "fmt" "time" @@ -13,7 +11,6 @@ import ( "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -24,7 +21,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg ctx = log.WithContext(ctx) - versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx, megolmBackupKey) + versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx) if err != nil { return "", err } else if versionInfo == nil { @@ -35,7 +32,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg return versionInfo.Version, err } -func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { +func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx) if err != nil { return nil, err @@ -51,24 +48,6 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, Stringer("key_backup_version", versionInfo.Version). Logger() - // https://spec.matrix.org/v1.10/client-server-api/#server-side-key-backups - // "Clients must only store keys in backups after they have ensured that the auth_data is trusted. This can be done either... - // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private - // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing - // from a verified device belonging to the same user." - if megolmBackupKey != nil { - megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) - if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { - log.Debug().Msg("Key backup is trusted based on derived public key") - return versionInfo, nil - } - log.Debug(). - Stringer("expected_key", megolmBackupDerivedPublicKey). - Stringer("actual_key", versionInfo.AuthData.PublicKey). - Msg("key backup public keys do not match, proceeding to check device signatures") - } - - // "...or checking that it is signed by the user’s master cross-signing key or by a verified device belonging to the same user" userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] if !ok { return nil, fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) @@ -95,7 +74,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, } else if device == nil { log.Warn().Err(err).Msg("Device does not exist, ignoring signature") continue - } else if !mach.IsDeviceTrusted(ctx, device) { + } else if !mach.IsDeviceTrusted(device) { log.Warn().Err(err).Msg("Device is not trusted") continue } else { @@ -108,7 +87,6 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, continue } else { // One of the signatures is valid, break from the loop. - log.Debug().Stringer("key_id", keyID).Msg("key backup is trusted based on matching signature") signatureVerified = true break } @@ -157,23 +135,13 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.Key return nil } -var ( - ErrUnknownAlgorithmInKeyBackup = errors.New("ignoring room key in backup with weird algorithm") - ErrMismatchingSessionIDInKeyBackup = errors.New("mismatched session ID while creating inbound group session from key backup") - ErrFailedToStoreNewInboundGroupSessionFromBackup = errors.New("failed to store new inbound group session from key backup") -) - -func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( - ctx context.Context, - version id.KeyBackupVersion, - roomID id.RoomID, - config *event.EncryptionEventContent, - sessionID id.SessionID, - keyBackupData *backup.MegolmSessionData, -) (*InboundGroupSession, error) { - log := zerolog.Ctx(ctx) +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { + log := zerolog.Ctx(ctx).With(). + Str("room_id", roomID.String()). + Str("session_id", sessionID.String()). + Logger() if keyBackupData.Algorithm != id.AlgorithmMegolmV1 { - return nil, fmt.Errorf("%w %s", ErrUnknownAlgorithmInKeyBackup, keyBackupData.Algorithm) + return nil, fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) } igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) @@ -181,60 +149,42 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( return nil, fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { log.Warn(). - Stringer("room_id", roomID). - Stringer("session_id", sessionID). Stringer("actual_session_id", igsInternal.ID()). Msg("Mismatched session ID while creating inbound group session from key backup") - return nil, ErrMismatchingSessionIDInKeyBackup + return nil, fmt.Errorf("mismatched session ID while creating inbound group session from key backup") } var maxAge time.Duration var maxMessages int - if config != nil { + if config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID); err != nil { + log.Error().Err(err).Msg("Failed to get encryption event for room") + } else if config != nil { maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond maxMessages = config.RotationPeriodMessages } - return &InboundGroupSession{ - Internal: igsInternal, + firstKnownIndex := igsInternal.FirstKnownIndex() + if firstKnownIndex > 0 { + log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") + } + + igs := &InboundGroupSession{ + Internal: *igsInternal, SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, - ForwardingChains: keyBackupData.ForwardingKeyChain, + ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), id: sessionID, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, KeyBackupVersion: version, - KeySource: id.KeySourceBackup, - }, nil -} - -func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { - config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Stringer("room_id", roomID). - Stringer("session_id", sessionID). - Msg("Failed to get encryption event for room") - } - imported, err := mach.ImportRoomKeyFromBackupWithoutSaving(ctx, version, roomID, config, sessionID, keyBackupData) - if err != nil { - return nil, err - } - firstKnownIndex := imported.Internal.FirstKnownIndex() - if firstKnownIndex > 0 { - zerolog.Ctx(ctx).Warn(). - Stringer("room_id", roomID). - Stringer("session_id", sessionID). - Uint32("first_known_index", firstKnownIndex). - Msg("Importing partial session") - } - err = mach.CryptoStore.PutGroupSession(ctx, imported) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrFailedToStoreNewInboundGroupSessionFromBackup, err) - } - mach.MarkSessionReceived(ctx, roomID, sessionID, firstKnownIndex) - return imported, nil + } + err = mach.CryptoStore.PutGroupSession(ctx, igs) + if err != nil { + return nil, fmt.Errorf("failed to store new inbound group session: %w", err) + } + mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex) + return igs, nil } diff --git a/crypto/keyexport.go b/crypto/keyexport.go index 1904c8a5..3d126db4 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -16,21 +16,15 @@ import ( "encoding/base64" "encoding/binary" "encoding/json" - "errors" "fmt" "math" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/exbytes" - "go.mau.fi/util/exerrors" "go.mau.fi/util/random" "golang.org/x/crypto/pbkdf2" "maunium.net/go/mautrix/id" ) -var ErrNoSessionsForExport = errors.New("no sessions provided for export") - type SenderClaimedKeys struct { Ed25519 id.Ed25519 `json:"ed25519"` } @@ -84,14 +78,22 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) return } -func exportSessions(sessions []*InboundGroupSession) ([]*ExportedSession, error) { - export := make([]*ExportedSession, len(sessions)) - var err error +func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) { + export := make([]ExportedSession, len(sessions)) for i, session := range sessions { - export[i], err = session.export() + key, err := session.Internal.Export(session.Internal.FirstKnownIndex()) if err != nil { return nil, fmt.Errorf("failed to export session: %w", err) } + export[i] = ExportedSession{ + Algorithm: id.AlgorithmMegolmV1, + ForwardingChains: session.ForwardingChains, + RoomID: session.RoomID, + SenderKey: session.SenderKey, + SenderClaimedKeys: SenderClaimedKeys{}, + SessionID: session.ID(), + SessionKey: string(key), + } } return export, nil } @@ -105,73 +107,38 @@ func exportSessionsJSON(sessions []*InboundGroupSession) ([]byte, error) { } func formatKeyExportData(data []byte) []byte { - encodedLen := base64.StdEncoding.EncodedLen(len(data)) - outputLength := len(exportPrefix) + - encodedLen + int(math.Ceil(float64(encodedLen)/exportLineLengthLimit)) + - len(exportSuffix) - output := make([]byte, 0, outputLength) - outputWriter := (*exbytes.Writer)(&output) - base64Writer := base64.NewEncoder(base64.StdEncoding, outputWriter) - lineByteCount := base64.StdEncoding.DecodedLen(exportLineLengthLimit) - exerrors.Must(outputWriter.WriteString(exportPrefix)) - for i := 0; i < len(data); i += lineByteCount { - exerrors.Must(base64Writer.Write(data[i:min(i+lineByteCount, len(data))])) - if i+lineByteCount >= len(data) { - exerrors.PanicIfNotNil(base64Writer.Close()) - } - exerrors.PanicIfNotNil(outputWriter.WriteByte('\n')) - } - exerrors.Must(outputWriter.WriteString(exportSuffix)) - if len(output) != outputLength { - panic(fmt.Errorf("unexpected length %d / %d", len(output), outputLength)) - } - return output -} + base64Data := make([]byte, base64.StdEncoding.EncodedLen(len(data))) + base64.StdEncoding.Encode(base64Data, data) -func ExportKeysIter(passphrase string, sessions dbutil.RowIter[*InboundGroupSession]) ([]byte, error) { - buf := bytes.NewBuffer(make([]byte, 0, 50*1024)) - enc := json.NewEncoder(buf) - buf.WriteByte('[') - err := sessions.Iter(func(session *InboundGroupSession) (bool, error) { - exported, err := session.export() - if err != nil { - return false, err - } - err = enc.Encode(exported) - if err != nil { - return false, err - } - buf.WriteByte(',') - return true, nil - }) - if err != nil { - return nil, err + // Prefix + data and newline for each 76 characters of data + suffix + outputLength := len(exportPrefix) + + len(base64Data) + int(math.Ceil(float64(len(base64Data))/exportLineLengthLimit)) + + len(exportSuffix) + + var buf bytes.Buffer + buf.Grow(outputLength) + buf.WriteString(exportPrefix) + for ptr := 0; ptr < len(base64Data); ptr += exportLineLengthLimit { + buf.Write(base64Data[ptr:min(ptr+exportLineLengthLimit, len(base64Data))]) + buf.WriteRune('\n') } - output := buf.Bytes() - if len(output) == 1 { - return nil, ErrNoSessionsForExport + buf.WriteString(exportSuffix) + if buf.Len() != buf.Cap() || buf.Len() != outputLength { + panic(fmt.Errorf("unexpected length %d / %d / %d", buf.Len(), buf.Cap(), outputLength)) } - output[len(output)-1] = ']' // Replace the last comma with a closing bracket - return EncryptKeyExport(passphrase, output) + return buf.Bytes() } // ExportKeys exports the given Megolm sessions with the format specified in the Matrix spec. // See https://spec.matrix.org/v1.2/client-server-api/#key-exports func ExportKeys(passphrase string, sessions []*InboundGroupSession) ([]byte, error) { - if len(sessions) == 0 { - return nil, ErrNoSessionsForExport - } + // Make all the keys necessary for exporting + encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) // Export all the given sessions and put them in JSON unencryptedData, err := exportSessionsJSON(sessions) if err != nil { return nil, err } - return EncryptKeyExport(passphrase, unencryptedData) -} - -func EncryptKeyExport(passphrase string, unencryptedData json.RawMessage) ([]byte, error) { - // Make all the keys necessary for exporting - encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) // The export data consists of: // 1 byte of export format version diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go deleted file mode 100644 index fd6f105d..00000000 --- a/crypto/keyexport_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package crypto_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exfmt" - - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/olm" -) - -func TestExportKeys(t *testing.T) { - acc := crypto.NewOlmAccount() - sess := exerrors.Must(crypto.NewInboundGroupSession( - acc.IdentityKey(), - acc.SigningKey(), - "!room:example.com", - exerrors.Must(olm.NewOutboundGroupSession()).Key(), - 7*exfmt.Day, - 100, - false, - )) - data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) - assert.NoError(t, err) - assert.Len(t, data, 893) -} diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 3ffc74a5..693ff6b8 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -36,10 +36,6 @@ var ( var exportPrefixBytes, exportSuffixBytes = []byte(exportPrefix), []byte(exportSuffix) func decodeKeyExport(data []byte) ([]byte, error) { - // Fix some types of corruption in the key export file before checking anything - if bytes.IndexByte(data, '\r') != -1 { - data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'}) - } // If the valid prefix and suffix aren't there, it's probably not a Matrix key export if !bytes.HasPrefix(data, exportPrefixBytes) { return nil, ErrMissingExportPrefix @@ -108,27 +104,26 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: igsInternal, - SigningKey: session.SenderClaimedKeys.Ed25519, - SenderKey: session.SenderKey, - RoomID: session.RoomID, + Internal: *igsInternal, + SigningKey: session.SenderClaimedKeys.Ed25519, + SenderKey: session.SenderKey, + RoomID: session.RoomID, + // TODO should we add something here to mark the signing key as unverified like key requests do? ForwardingChains: session.ForwardingChains, - KeySource: id.KeySourceImport, - ReceivedAt: time.Now().UTC(), + + ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) firstKnownIndex := igs.Internal.FirstKnownIndex() if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { - // We already have an equivalent or better session in the store, so don't override it, - // but do notify the session received callback just in case. - mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex()) + // We already have an equivalent or better session in the store, so don't override it. return false, nil } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) + mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 19a68c87..4d3b6f7e 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -59,15 +59,11 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to select { case <-keyResponseReceived: // key request successful - mach.Log.Debug(). - Stringer("session_id", sessionID). - Msg("Key for session was received, cancelling other key requests") + mach.Log.Debug().Msgf("Key for session %v was received, cancelling other key requests", sessionID) resChan <- true case <-ctx.Done(): // if the context is done, key request was unsuccessful - mach.Log.Debug().Err(err). - Stringer("session_id", sessionID). - Msg("Context closed before forwarded key for session received, sending key request cancellation") + mach.Log.Debug().Msgf("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID) resChan <- false } @@ -178,7 +174,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } igs := &InboundGroupSession{ - Internal: igsInternal, + Internal: *igsInternal, SigningKey: evt.Keys.Ed25519, SenderKey: content.SenderKey, RoomID: content.RoomID, @@ -189,7 +185,6 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: content.IsScheduled, - KeySource: id.KeySourceForward, } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { @@ -201,7 +196,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.MarkSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) + mach.markSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) log.Debug().Msg("Received forwarded inbound group session") return true } @@ -215,7 +210,6 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare RoomID: request.RoomID, Algorithm: request.Algorithm, SessionID: request.SessionID, - //lint:ignore SA1019 This is just echoing back the deprecated field SenderKey: request.SenderKey, Code: rejection.Code, Reason: rejection.Reason, @@ -265,14 +259,9 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") return &KeyShareRejectNoResponse } else if !isShared { - igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID) - if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey { - log.Debug().Msg("Rejecting key request for unshared session") - return &KeyShareRejectNotRecipient - } - // Note: this case will also happen for redacted sessions and database errors - log.Debug().Msg("Rejecting key request for session created by another device") - return &KeyShareRejectNoResponse + // TODO differentiate session not shared with requester vs session not created by this device? + log.Debug().Msg("Rejecting key request for unshared session") + return &KeyShareRejectNotRecipient } log.Debug().Msg("Accepting key request for shared session") return nil @@ -282,7 +271,7 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev } else if device.Trust == id.TrustStateBlacklisted { log.Debug().Msg("Rejecting key request from blacklisted device") return &KeyShareRejectBlacklisted - } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState >= mach.ShareKeysMinTrust { + } else if trustState := mach.ResolveTrust(device); trustState >= mach.ShareKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). @@ -330,9 +319,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - if sender != mach.Client.UserID { - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) - } + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) } else { log.Error().Err(err).Msg("Failed to get group session to forward") mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) @@ -340,14 +327,12 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - if sender != mach.Client.UserID { - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) - } + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) return } if internalID := igs.ID(); internalID != content.Body.SessionID { // Should this be an error? - log = log.With().Stringer("unexpected_session_id", internalID).Logger() + log = log.With().Str("unexpected_session_id", internalID.String()).Logger() } firstKnownIndex := igs.Internal.FirstKnownIndex() @@ -358,6 +343,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) return } + if igs.ForwardingChains == nil { + igs.ForwardingChains = []string{} + } forwardedRoomKey := event.Content{ Parsed: &event.ForwardedRoomKeyEventContent{ @@ -367,7 +355,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionID: igs.ID(), SessionKey: string(exportedKey), }, - SenderKey: igs.SenderKey, + SenderKey: content.Body.SenderKey, ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go deleted file mode 100644 index 0350f083..00000000 --- a/crypto/libolm/account.go +++ /dev/null @@ -1,419 +0,0 @@ -package libolm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - -import ( - "crypto/rand" - "encoding/base64" - "encoding/json" - "runtime" - "unsafe" - - "github.com/tidwall/gjson" - - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/id" -) - -// Account stores a device account for end to end encrypted messaging. -type Account struct { - int *C.OlmAccount - mem []byte -} - -// Ensure that [Account] implements [olm.Account]. -var _ olm.Account = (*Account)(nil) - -// AccountFromPickled loads an Account from a pickled base64 string. Decrypts -// the Account using the supplied key. Returns error on failure. If the key -// doesn't match the one used to encrypt the Account then the error will be -// "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". -func AccountFromPickled(pickled, key []byte) (*Account, error) { - if len(pickled) == 0 { - return nil, olm.ErrEmptyInput - } - a := NewBlankAccount() - return a, a.Unpickle(pickled, key) -} - -func NewBlankAccount() *Account { - memory := make([]byte, accountSize()) - return &Account{ - int: C.olm_account(unsafe.Pointer(unsafe.SliceData(memory))), - mem: memory, - } -} - -// NewAccount creates a new [Account]. -func NewAccount() (*Account, error) { - a := NewBlankAccount() - random := make([]byte, a.createRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(olm.ErrNotEnoughGoRandom) - } - ret := C.olm_create_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(random)), - C.size_t(len(random))) - runtime.KeepAlive(random) - if ret == errorVal() { - return nil, a.lastError() - } else { - return a, nil - } -} - -// accountSize returns the size of an account object in bytes. -func accountSize() uint { - return uint(C.olm_account_size()) -} - -// lastError returns an error describing the most recent error to happen to an -// account. -func (a *Account) lastError() error { - return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))) -} - -// Clear clears the memory used to back this Account. -func (a *Account) Clear() error { - r := C.olm_clear_account((*C.OlmAccount)(a.int)) - if r == errorVal() { - return a.lastError() - } else { - return nil - } -} - -// pickleLen returns the number of bytes needed to store an Account. -func (a *Account) pickleLen() uint { - return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int))) -} - -// createRandomLen returns the number of random bytes needed to create an -// Account. -func (a *Account) createRandomLen() uint { - return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int))) -} - -// identityKeysLen returns the size of the output buffer needed to hold the -// identity keys. -func (a *Account) identityKeysLen() uint { - return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))) -} - -// signatureLen returns the length of an ed25519 signature encoded as base64. -func (a *Account) signatureLen() uint { - return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int))) -} - -// oneTimeKeysLen returns the size of the output buffer needed to hold the one -// time keys. -func (a *Account) oneTimeKeysLen() uint { - return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))) -} - -// genOneTimeKeysRandomLen returns the number of random bytes needed to -// generate a given number of new one time keys. -func (a *Account) genOneTimeKeysRandomLen(num uint) uint { - return uint(C.olm_account_generate_one_time_keys_random_length( - (*C.OlmAccount)(a.int), - C.size_t(num))) -} - -// Pickle returns an Account as a base64 string. Encrypts the Account using the -// supplied key. -func (a *Account) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided - } - pickled := make([]byte, a.pickleLen()) - r := C.olm_pickle_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled))) - if r == errorVal() { - return nil, a.lastError() - } - return pickled[:r], nil -} - -func (a *Account) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return olm.ErrNoKeyProvided - } - r := C.olm_unpickle_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled))) - if r == errorVal() { - return a.lastError() - } - return nil -} - -// Deprecated -func (a *Account) GobEncode() ([]byte, error) { - pickled, err := a.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (a *Account) GobDecode(rawPickled []byte) error { - if a.int == nil { - *a = *NewBlankAccount() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return a.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (a *Account) MarshalJSON() ([]byte, error) { - pickled, err := a.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (a *Account) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString - } - if a.int == nil { - *a = *NewBlankAccount() - } - return a.Unpickle(data[1:len(data)-1], pickleKey) -} - -// IdentityKeysJSON returns the public parts of the identity keys for the Account. -func (a *Account) IdentityKeysJSON() ([]byte, error) { - identityKeys := make([]byte, a.identityKeysLen()) - r := C.olm_account_identity_keys( - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(identityKeys)), - C.size_t(len(identityKeys))) - if r == errorVal() { - return nil, a.lastError() - } else { - return identityKeys, nil - } -} - -// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity -// keys for the Account. -func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { - identityKeysJSON, err := a.IdentityKeysJSON() - if err != nil { - return "", "", err - } - results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519") - return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str), nil -} - -// Sign returns the signature of a message using the ed25519 key for this -// Account. -func (a *Account) Sign(message []byte) ([]byte, error) { - if len(message) == 0 { - panic(olm.ErrEmptyInput) - } - signature := make([]byte, a.signatureLen()) - r := C.olm_account_sign( - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(message)), - C.size_t(len(message)), - unsafe.Pointer(unsafe.SliceData(signature)), - C.size_t(len(signature))) - runtime.KeepAlive(message) - if r == errorVal() { - panic(a.lastError()) - } - return signature, nil -} - -// OneTimeKeys returns the public parts of the unpublished one time keys for -// the Account. -// -// The returned data is a struct with the single value "Curve25519", which is -// itself an object mapping key id to base64-encoded Curve25519 key. For -// example: -// -// { -// Curve25519: { -// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", -// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" -// } -// } -func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { - oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) - r := C.olm_account_one_time_keys( - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(oneTimeKeysJSON)), - C.size_t(len(oneTimeKeysJSON)), - ) - if r == errorVal() { - return nil, a.lastError() - } - var oneTimeKeys struct { - Curve25519 map[string]id.Curve25519 `json:"curve25519"` - } - return oneTimeKeys.Curve25519, json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys) -} - -// MarkKeysAsPublished marks the current set of one time keys as being -// published. -func (a *Account) MarkKeysAsPublished() { - C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int)) -} - -// MaxNumberOfOneTimeKeys returns the largest number of one time keys this -// Account can store. -func (a *Account) MaxNumberOfOneTimeKeys() uint { - return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))) -} - -// GenOneTimeKeys generates a number of new one time keys. If the total number -// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old -// keys are discarded. -func (a *Account) GenOneTimeKeys(num uint) error { - random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) - _, err := rand.Read(random) - if err != nil { - return olm.ErrNotEnoughGoRandom - } - r := C.olm_account_generate_one_time_keys( - (*C.OlmAccount)(a.int), - C.size_t(num), - unsafe.Pointer(unsafe.SliceData(random)), - C.size_t(len(random)), - ) - runtime.KeepAlive(random) - if r == errorVal() { - return a.lastError() - } - return nil -} - -// NewOutboundSession creates a new out-bound session for sending messages to a -// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the -// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" -func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { - if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, olm.ErrEmptyInput - } - s := NewBlankSession() - random := make([]byte, s.createOutboundRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(olm.ErrNotEnoughGoRandom) - } - theirIdentityKeyCopy := []byte(theirIdentityKey) - theirOneTimeKeyCopy := []byte(theirOneTimeKey) - r := C.olm_create_outbound_session( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), - C.size_t(len(theirIdentityKeyCopy)), - unsafe.Pointer(unsafe.SliceData(theirOneTimeKeyCopy)), - C.size_t(len(theirOneTimeKeyCopy)), - unsafe.Pointer(unsafe.SliceData(random)), - C.size_t(len(random)), - ) - runtime.KeepAlive(random) - runtime.KeepAlive(theirIdentityKeyCopy) - runtime.KeepAlive(theirOneTimeKeyCopy) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// NewInboundSession creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. If -// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If -// the message was for an unsupported protocol version then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the -// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one -// time key then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { - if len(oneTimeKeyMsg) == 0 { - return nil, olm.ErrEmptyInput - } - s := NewBlankSession() - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) - r := C.olm_create_inbound_session( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(oneTimeKeyMsgCopy) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// NewInboundSessionFrom creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. If -// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If -// the message was for an unsupported protocol version then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the -// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one -// time key then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { - if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { - return nil, olm.ErrEmptyInput - } - theirIdentityKeyCopy := []byte(*theirIdentityKey) - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) - s := NewBlankSession() - r := C.olm_create_inbound_session_from( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), - C.size_t(len(theirIdentityKeyCopy)), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(theirIdentityKeyCopy) - runtime.KeepAlive(oneTimeKeyMsgCopy) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// RemoveOneTimeKeys removes the one time keys that the session used from the -// Account. Returns error on failure. If the Account doesn't have any -// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) RemoveOneTimeKeys(s olm.Session) error { - r := C.olm_remove_one_time_keys( - (*C.OlmAccount)(a.int), - (*C.OlmSession)(s.(*Session).int), - ) - if r == errorVal() { - return a.lastError() - } - return nil -} diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go deleted file mode 100644 index 6fb5512b..00000000 --- a/crypto/libolm/error.go +++ /dev/null @@ -1,37 +0,0 @@ -package libolm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - -import ( - "fmt" - - "maunium.net/go/mautrix/crypto/olm" -) - -var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall, - "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion, - "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat, - "BAD_MESSAGE_MAC": olm.ErrBadMAC, - "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID, - "INVALID_BASE64": olm.ErrLibolmInvalidBase64, - "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey, - "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion, - "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle, - "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey, - "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle, - "BAD_SIGNATURE": olm.ErrBadSignature, - "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall, -} - -func convertError(errCode string) error { - err, ok := errorMap[errCode] - if ok { - return err - } - return fmt.Errorf("unknown error: %s", errCode) -} diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go deleted file mode 100644 index 8815ac32..00000000 --- a/crypto/libolm/inboundgroupsession.go +++ /dev/null @@ -1,327 +0,0 @@ -package libolm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - -import ( - "bytes" - "encoding/base64" - "runtime" - "unsafe" - - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/id" -) - -// InboundGroupSession stores an inbound encrypted messaging session for a -// group. -type InboundGroupSession struct { - int *C.OlmInboundGroupSession - mem []byte -} - -// Ensure that [InboundGroupSession] implements [olm.InboundGroupSession]. -var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) - -// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled -// base64 string. Decrypts the InboundGroupSession using the supplied key. -// Returns error on failure. If the key doesn't match the one used to encrypt -// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". -func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { - if len(pickled) == 0 { - return nil, olm.ErrEmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - s := NewBlankInboundGroupSession() - return s, s.Unpickle(pickled, key) -} - -// NewInboundGroupSession creates a new inbound group session from a key -// exported from OutboundGroupSession.Key(). Returns error on failure. -// If the sessionKey is not valid base64 the error will be -// "OLM_INVALID_BASE64". If the session_key is invalid the error will be -// "OLM_BAD_SESSION_KEY". -func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput - } - s := NewBlankInboundGroupSession() - r := C.olm_init_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), - C.size_t(len(sessionKey)), - ) - runtime.KeepAlive(sessionKey) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// InboundGroupSessionImport imports an inbound group session from a previous -// export. Returns error on failure. If the sessionKey is not valid base64 -// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the -// error will be "OLM_BAD_SESSION_KEY". -func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput - } - s := NewBlankInboundGroupSession() - r := C.olm_import_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), - C.size_t(len(sessionKey)), - ) - runtime.KeepAlive(sessionKey) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// inboundGroupSessionSize is the size of an inbound group session object in -// bytes. -func inboundGroupSessionSize() uint { - return uint(C.olm_inbound_group_session_size()) -} - -// newInboundGroupSession initialises an empty InboundGroupSession. -func NewBlankInboundGroupSession() *InboundGroupSession { - memory := make([]byte, inboundGroupSessionSize()) - return &InboundGroupSession{ - int: C.olm_inbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to an -// inbound group session. -func (s *InboundGroupSession) lastError() error { - return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int)))) -} - -// Clear clears the memory used to back this InboundGroupSession. -func (s *InboundGroupSession) Clear() error { - r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int)) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// pickleLen returns the number of bytes needed to store an inbound group -// session. -func (s *InboundGroupSession) pickleLen() uint { - return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) -} - -// Pickle returns an InboundGroupSession as a base64 string. Encrypts the -// InboundGroupSession using the supplied key. -func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(key) - if r == errorVal() { - return nil, s.lastError() - } - return pickled[:r], nil -} - -func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return olm.ErrNoKeyProvided - } else if len(pickled) == 0 { - return olm.ErrEmptyInput - } - r := C.olm_unpickle_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(key) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *InboundGroupSession) GobEncode() ([]byte, error) { - pickled, err := s.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankInboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { - pickled, err := s.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankInboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a -// given message could decode to. The actual size could be different due to -// padding. Returns error on failure. If the message base64 couldn't be -// decoded then the error will be "INVALID_BASE64". If the message is for an -// unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be "BAD_MESSAGE_FORMAT". -func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { - if len(message) == 0 { - return 0, olm.ErrEmptyInput - } - // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it - messageCopy := bytes.Clone(message) - r := C.olm_group_decrypt_max_plaintext_length( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), - C.size_t(len(messageCopy)), - ) - runtime.KeepAlive(messageCopy) - if r == errorVal() { - return 0, s.lastError() - } - return uint(r), nil -} - -// Decrypt decrypts a message using the InboundGroupSession. Returns the the -// plain-text and message index on success. Returns error on failure. If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". If the -// message is for an unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then the -// error will be "BAD_MESSAGE_MAC". If we do not have a session key -// corresponding to the message's index (ie, it was sent before the session key -// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". -func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { - if len(message) == 0 { - return nil, 0, olm.ErrEmptyInput - } - decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) - if err != nil { - return nil, 0, err - } - messageCopy := bytes.Clone(message) - plaintext := make([]byte, decryptMaxPlaintextLen) - var messageIndex uint32 - r := C.olm_group_decrypt( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), - C.size_t(len(messageCopy)), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), - C.size_t(len(plaintext)), - (*C.uint32_t)(unsafe.Pointer(&messageIndex)), - ) - runtime.KeepAlive(messageCopy) - if r == errorVal() { - return nil, 0, s.lastError() - } - return plaintext[:r], uint(messageIndex), nil -} - -// sessionIdLen returns the number of bytes needed to store a session ID. -func (s *InboundGroupSession) sessionIdLen() uint { - return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int))) -} - -// ID returns a base64-encoded identifier for this session. -func (s *InboundGroupSession) ID() id.SessionID { - sessionID := make([]byte, s.sessionIdLen()) - r := C.olm_inbound_group_session_id( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), - C.size_t(len(sessionID)), - ) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID[:r]) -} - -// FirstKnownIndex returns the first message index we know how to decrypt. -func (s *InboundGroupSession) FirstKnownIndex() uint32 { - return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int))) -} - -// IsVerified check if the session has been verified as a valid session. (A -// session is verified either because the original session share was signed, or -// because we have subsequently successfully decrypted a message.) -func (s *InboundGroupSession) IsVerified() bool { - return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) == 1 -} - -// exportLen returns the number of bytes needed to export an inbound group -// session. -func (s *InboundGroupSession) exportLen() uint { - return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) -} - -// Export returns the base64-encoded ratchet key for this session, at the given -// index, in a format which can be used by -// InboundGroupSession.InboundGroupSessionImport(). Encrypts the -// InboundGroupSession using the supplied key. Returns error on failure. -// if we do not have a session key corresponding to the given index (ie, it was -// sent before the session key was shared with us) the error will be -// "OLM_UNKNOWN_MESSAGE_INDEX". -func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { - key := make([]byte, s.exportLen()) - r := C.olm_export_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(key))), - C.size_t(len(key)), - C.uint32_t(messageIndex), - ) - if r == errorVal() { - return nil, s.lastError() - } - return key[:r], nil -} diff --git a/crypto/libolm/libolm.go b/crypto/libolm/libolm.go deleted file mode 100644 index 18815767..00000000 --- a/crypto/libolm/libolm.go +++ /dev/null @@ -1,10 +0,0 @@ -package libolm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - -// errorVal returns the value that olm functions return if there was an error. -func errorVal() C.size_t { - return C.olm_error() -} diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go deleted file mode 100644 index ca5b68f7..00000000 --- a/crypto/libolm/outboundgroupsession.go +++ /dev/null @@ -1,245 +0,0 @@ -package libolm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - -import ( - "crypto/rand" - "encoding/base64" - "runtime" - "unsafe" - - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/id" -) - -// OutboundGroupSession stores an outbound encrypted messaging session -// for a group. -type OutboundGroupSession struct { - int *C.OlmOutboundGroupSession - mem []byte -} - -// Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession]. -var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil) - -func NewOutboundGroupSession() (*OutboundGroupSession, error) { - s := NewBlankOutboundGroupSession() - random := make([]byte, s.createRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - return nil, err - } - r := C.olm_init_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(random))), - C.size_t(len(random)), - ) - runtime.KeepAlive(random) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// outboundGroupSessionSize is the size of an outbound group session object in -// bytes. -func outboundGroupSessionSize() uint { - return uint(C.olm_outbound_group_session_size()) -} - -// NewBlankOutboundGroupSession initialises an empty [OutboundGroupSession]. -func NewBlankOutboundGroupSession() *OutboundGroupSession { - memory := make([]byte, outboundGroupSessionSize()) - return &OutboundGroupSession{ - int: C.olm_outbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to an -// outbound group session. -func (s *OutboundGroupSession) lastError() error { - return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int)))) -} - -// Clear clears the memory used to back this OutboundGroupSession. -func (s *OutboundGroupSession) Clear() error { - r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int)) - if r == errorVal() { - return s.lastError() - } else { - return nil - } -} - -// pickleLen returns the number of bytes needed to store an outbound group -// session. -func (s *OutboundGroupSession) pickleLen() uint { - return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the -// OutboundGroupSession using the supplied key. -func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(key) - if r == errorVal() { - return nil, s.lastError() - } - return pickled[:r], nil -} - -func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return olm.ErrNoKeyProvided - } - r := C.olm_unpickle_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(pickled) - runtime.KeepAlive(key) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *OutboundGroupSession) GobEncode() ([]byte, error) { - pickled, err := s.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankOutboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { - pickled, err := s.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankOutboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// createRandomLen returns the number of random bytes needed to create an -// Account. -func (s *OutboundGroupSession) createRandomLen() uint { - return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// encryptMsgLen returns the size of the next message in bytes for the given -// number of plain-text bytes. -func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { - return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen))) -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { - if len(plaintext) == 0 { - return nil, olm.ErrEmptyInput - } - message := make([]byte, s.encryptMsgLen(len(plaintext))) - r := C.olm_group_encrypt( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), - C.size_t(len(plaintext)), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), - C.size_t(len(message)), - ) - runtime.KeepAlive(plaintext) - if r == errorVal() { - return nil, s.lastError() - } - return message[:r], nil -} - -// sessionIdLen returns the number of bytes needed to store a session ID. -func (s *OutboundGroupSession) sessionIdLen() uint { - return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// ID returns a base64-encoded identifier for this session. -func (s *OutboundGroupSession) ID() id.SessionID { - sessionID := make([]byte, s.sessionIdLen()) - r := C.olm_outbound_group_session_id( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), - C.size_t(len(sessionID)), - ) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID[:r]) -} - -// MessageIndex returns the message index for this session. Each message is -// sent with an increasing index; this returns the index for the next message. -func (s *OutboundGroupSession) MessageIndex() uint { - return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int))) -} - -// sessionKeyLen returns the number of bytes needed to store a session key. -func (s *OutboundGroupSession) sessionKeyLen() uint { - return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// Key returns the base64-encoded current ratchet key for this session. -func (s *OutboundGroupSession) Key() string { - sessionKey := make([]byte, s.sessionKeyLen()) - r := C.olm_outbound_group_session_key( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), - C.size_t(len(sessionKey)), - ) - if r == errorVal() { - panic(s.lastError()) - } - return string(sessionKey[:r]) -} diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go deleted file mode 100644 index ddf84613..00000000 --- a/crypto/libolm/register.go +++ /dev/null @@ -1,75 +0,0 @@ -package libolm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" -import ( - "unsafe" - - "maunium.net/go/mautrix/crypto/olm" -) - -var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") - -func Register() { - olm.Driver = "libolm" - - olm.GetVersion = func() (major, minor, patch uint8) { - C.olm_get_library_version( - (*C.uint8_t)(unsafe.Pointer(&major)), - (*C.uint8_t)(unsafe.Pointer(&minor)), - (*C.uint8_t)(unsafe.Pointer(&patch))) - return 3, 2, 15 - } - olm.SetPickleKeyImpl = func(key []byte) { - pickleKey = key - } - - olm.InitNewAccount = func() (olm.Account, error) { - return NewAccount() - } - olm.InitBlankAccount = func() olm.Account { - return NewBlankAccount() - } - olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { - return AccountFromPickled(pickled, key) - } - - olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { - return SessionFromPickled(pickled, key) - } - olm.InitNewBlankSession = func() olm.Session { - return NewBlankSession() - } - - olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } - olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { - return NewPKSigningFromSeed(seed) - } - olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { - return NewPkDecryption(privateKey) - } - - olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionFromPickled(pickled, key) - } - olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return NewInboundGroupSession(sessionKey) - } - olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionImport(sessionKey) - } - olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { - return NewBlankInboundGroupSession() - } - - olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, olm.ErrEmptyInput - } - s := NewBlankOutboundGroupSession() - return s, s.Unpickle(pickled, key) - } - olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } - olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } -} diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go deleted file mode 100644 index 1441df26..00000000 --- a/crypto/libolm/session.go +++ /dev/null @@ -1,401 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package libolm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -// #include -// #include -// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak)); -// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) { -// if (olm_session_describe) { -// olm_session_describe(session, buf, buflen); -// } else { -// sprintf(buf, "olm_session_describe not supported"); -// } -// } -import "C" - -import ( - "crypto/rand" - "encoding/base64" - "runtime" - "unsafe" - - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/id" -) - -// Session stores an end to end encrypted messaging session. -type Session struct { - int *C.OlmSession - mem []byte -} - -// Ensure that [Session] implements [olm.Session]. -var _ olm.Session = (*Session)(nil) - -// sessionSize is the size of a session object in bytes. -func sessionSize() uint { - return uint(C.olm_session_size()) -} - -// SessionFromPickled loads a Session from a pickled base64 string. Decrypts -// the Session using the supplied key. Returns error on failure. If the key -// doesn't match the one used to encrypt the Session then the error will be -// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". -func SessionFromPickled(pickled, key []byte) (*Session, error) { - if len(pickled) == 0 { - return nil, olm.ErrEmptyInput - } - s := NewBlankSession() - return s, s.Unpickle(pickled, key) -} - -func NewBlankSession() *Session { - memory := make([]byte, sessionSize()) - return &Session{ - int: C.olm_session(unsafe.Pointer(unsafe.SliceData(memory))), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to a -// session. -func (s *Session) lastError() error { - return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int)))) -} - -// Clear clears the memory used to back this Session. -func (s *Session) Clear() error { - r := C.olm_clear_session((*C.OlmSession)(s.int)) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// pickleLen returns the number of bytes needed to store a session. -func (s *Session) pickleLen() uint { - return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int))) -} - -// createOutboundRandomLen returns the number of random bytes needed to create -// an outbound session. -func (s *Session) createOutboundRandomLen() uint { - return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))) -} - -// idLen returns the length of the buffer needed to return the id for this -// session. -func (s *Session) idLen() uint { - return uint(C.olm_session_id_length((*C.OlmSession)(s.int))) -} - -// encryptRandomLen returns the number of random bytes needed to encrypt the -// next message. -func (s *Session) encryptRandomLen() uint { - return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int))) -} - -// encryptMsgLen returns the size of the next message in bytes for the given -// number of plain-text bytes. -func (s *Session) encryptMsgLen(plainTextLen int) uint { - return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))) -} - -// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a -// given message could decode to. The actual size could be different due to -// padding. Returns error on failure. If the message base64 couldn't be -// decoded then the error will be "INVALID_BASE64". If the message is for an -// unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be "BAD_MESSAGE_FORMAT". -func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { - if len(message) == 0 { - return 0, olm.ErrEmptyInput - } - messageCopy := []byte(message) - r := C.olm_decrypt_max_plaintext_length( - (*C.OlmSession)(s.int), - C.size_t(msgType), - unsafe.Pointer(unsafe.SliceData((messageCopy))), - C.size_t(len(messageCopy)), - ) - runtime.KeepAlive(messageCopy) - if r == errorVal() { - return 0, s.lastError() - } - return uint(r), nil -} - -// Pickle returns a Session as a base64 string. Encrypts the Session using the -// supplied key. -func (s *Session) Pickle(key []byte) ([]byte, error) { - if len(key) == 0 { - return nil, olm.ErrNoKeyProvided - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled))) - runtime.KeepAlive(key) - if r == errorVal() { - panic(s.lastError()) - } - return pickled[:r], nil -} - -// Unpickle unpickles the base64-encoded Olm session decrypting it with the -// provided key. This function mutates the input pickled data slice. -func (s *Session) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return olm.ErrNoKeyProvided - } - r := C.olm_unpickle_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), - C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled))) - runtime.KeepAlive(pickled) - runtime.KeepAlive(key) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *Session) GobEncode() ([]byte, error) { - pickled, err := s.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *Session) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *Session) MarshalJSON() ([]byte, error) { - pickled, err := s.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *Session) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// Id returns an identifier for this Session. Will be the same for both ends -// of the conversation. -func (s *Session) ID() id.SessionID { - sessionID := make([]byte, s.idLen()) - r := C.olm_session_id( - (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(sessionID)), - C.size_t(len(sessionID)), - ) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID) -} - -// HasReceivedMessage returns true if this session has received any message. -func (s *Session) HasReceivedMessage() bool { - switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) { - case 0: - return false - default: - return true - } -} - -// MatchesInboundSession checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". If the message was for an unsupported protocol version -// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be -// decoded then then the error will be "BAD_MESSAGE_FORMAT". -func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { - if len(oneTimeKeyMsg) == 0 { - return false, olm.ErrEmptyInput - } - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) - r := C.olm_matches_inbound_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(oneTimeKeyMsgCopy) - if r == 1 { - return true, nil - } else if r == 0 { - return false, nil - } else { // if r == errorVal() - return false, s.lastError() - } -} - -// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". If the message was for an unsupported protocol version -// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be -// decoded then then the error will be "BAD_MESSAGE_FORMAT". -func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { - if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, olm.ErrEmptyInput - } - theirIdentityKeyCopy := []byte(theirIdentityKey) - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) - r := C.olm_matches_inbound_session_from( - (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), - C.size_t(len(theirIdentityKeyCopy)), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(theirIdentityKeyCopy) - runtime.KeepAlive(oneTimeKeyMsgCopy) - if r == 1 { - return true, nil - } else if r == 0 { - return false, nil - } else { // if r == errorVal() - return false, s.lastError() - } -} - -// EncryptMsgType returns the type of the next message that Encrypt will -// return. Returns MsgTypePreKey if the message will be a PRE_KEY message. -// Returns MsgTypeMsg if the message will be a normal message. Returns error -// on failure. -func (s *Session) EncryptMsgType() id.OlmMsgType { - switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) { - case C.size_t(id.OlmMsgTypePreKey): - return id.OlmMsgTypePreKey - case C.size_t(id.OlmMsgTypeMsg): - return id.OlmMsgTypeMsg - default: - panic("olm_encrypt_message_type returned invalid result") - } -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { - if len(plaintext) == 0 { - return 0, nil, olm.ErrEmptyInput - } - // Make the slice be at least length 1 - random := make([]byte, s.encryptRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - // TODO can we just return err here? - return 0, nil, olm.ErrNotEnoughGoRandom - } - messageType := s.EncryptMsgType() - message := make([]byte, s.encryptMsgLen(len(plaintext))) - r := C.olm_encrypt( - (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(plaintext)), - C.size_t(len(plaintext)), - unsafe.Pointer(unsafe.SliceData(random)), - C.size_t(len(random)), - unsafe.Pointer(unsafe.SliceData(message)), - C.size_t(len(message)), - ) - runtime.KeepAlive(plaintext) - runtime.KeepAlive(random) - if r == errorVal() { - return 0, nil, s.lastError() - } - return messageType, message[:r], nil -} - -// Decrypt decrypts a message using the Session. Returns the the plain-text on -// success. Returns error on failure. If the base64 couldn't be decoded then -// the error will be "INVALID_BASE64". If the message is for an unsupported -// version of the protocol then the error will be "BAD_MESSAGE_VERSION". If -// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT". -// If the MAC on the message was invalid then the error will be -// "BAD_MESSAGE_MAC". -func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { - if len(message) == 0 { - return nil, olm.ErrEmptyInput - } - decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) - if err != nil { - return nil, err - } - messageCopy := []byte(message) - plaintext := make([]byte, decryptMaxPlaintextLen) - r := C.olm_decrypt( - (*C.OlmSession)(s.int), - C.size_t(msgType), - unsafe.Pointer(unsafe.SliceData(messageCopy)), - C.size_t(len(messageCopy)), - unsafe.Pointer(unsafe.SliceData(plaintext)), - C.size_t(len(plaintext)), - ) - runtime.KeepAlive(messageCopy) - if r == errorVal() { - return nil, s.lastError() - } - return plaintext[:r], nil -} - -// https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393 -const maxDescribeSize = 600 - -// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. -func (s *Session) Describe() string { - desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize))) - defer C.free(unsafe.Pointer(desc)) - C.meowlm_session_describe( - (*C.OlmSession)(s.int), - desc, - C.size_t(maxDescribeSize), - ) - return C.GoString(desc) -} diff --git a/crypto/machine.go b/crypto/machine.go index fa051f94..2477b9e1 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -11,16 +11,13 @@ import ( "errors" "fmt" "sync" - "sync/atomic" "time" "github.com/rs/zerolog" - "go.mau.fi/util/ptr" "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -35,12 +32,7 @@ type OlmMachine struct { CryptoStore Store StateStore StateStore - backgroundCtx context.Context - cancelBackgroundCtx context.CancelFunc - - PlaintextMentions bool - MSC4392Relations bool - AllowEncryptedState bool + PlaintextMentions bool // Never ask the server for keys automatically as a side effect during Megolm decryption. DisableDecryptKeyFetching bool @@ -48,8 +40,6 @@ type OlmMachine struct { // Don't mark outbound Olm sessions as shared for devices they were initially sent to. DisableSharedGroupSessionTracking bool - IgnorePostDecryptionParseErrors bool - SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState @@ -70,17 +60,13 @@ type OlmMachine struct { devicesToUnwedgeLock sync.Mutex recentlyUnwedged map[id.IdentityKey]time.Time recentlyUnwedgedLock sync.Mutex - olmHashSavePoints []time.Time - lastHashDelete time.Time - olmHashSavePointLock sync.Mutex olmLock sync.Mutex megolmEncryptLock sync.Mutex megolmDecryptLock sync.Mutex - otkUploadLock sync.Mutex - lastOTKUpload time.Time - receivedOTKsForSelf atomic.Bool + otkUploadLock sync.Mutex + lastOTKUpload time.Time CrossSigningKeys *CrossSigningKeysCache crossSigningPubkeys *CrossSigningPublicKeysCache @@ -136,7 +122,6 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor recentlyUnwedged: make(map[id.IdentityKey]time.Time), secretListeners: make(map[string]chan<- string), } - mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(context.Background()) mach.AllowKeyShare = mach.defaultAllowKeyShare return mach } @@ -149,11 +134,6 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (mach *OlmMachine) SetBackgroundCtx(ctx context.Context) { - mach.cancelBackgroundCtx() - mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(ctx) -} - // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. func (mach *OlmMachine) Load(ctx context.Context) (err error) { @@ -164,23 +144,9 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { if mach.account == nil { mach.account = NewOlmAccount() } - zerolog.Ctx(ctx).Debug(). - Str("machine_ptr", fmt.Sprintf("%p", mach)). - Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)). - Str("olm_driver", olm.Driver). - Msg("Loaded olm account") return nil } -func (mach *OlmMachine) Destroy() { - mach.Log.Debug(). - Str("machine_ptr", fmt.Sprintf("%p", mach)). - Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)). - Msg("Destroying olm machine") - mach.cancelBackgroundCtx() - // TODO actually destroy something? -} - func (mach *OlmMachine) saveAccount(ctx context.Context) error { err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { @@ -206,7 +172,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error { func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { start := time.Now() return func() { - duration := time.Since(start) + duration := time.Now().Sub(start) if duration > expectedDuration { zerolog.Ctx(ctx).Warn(). Str("action", thing). @@ -242,14 +208,13 @@ func (mach *OlmMachine) OwnIdentity() *id.Device { } } -type ASEventProcessor interface { +type asEventProcessor interface { On(evtType event.Type, handler func(ctx context.Context, evt *event.Event)) OnOTK(func(ctx context.Context, otk *mautrix.OTKCount)) OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string)) - Dispatch(ctx context.Context, evt *event.Event) } -func (mach *OlmMachine) AddAppserviceListener(ep ASEventProcessor) { +func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) { // ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent) ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent) @@ -279,29 +244,14 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic } } -func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) bool { - if mach.crossSigningPubkeys == nil || otkCount.UserID != mach.Client.UserID { - return false - } - switch id.Ed25519(otkCount.DeviceID) { - case mach.crossSigningPubkeys.MasterKey, mach.crossSigningPubkeys.UserSigningKey, mach.crossSigningPubkeys.SelfSigningKey: - return true - } - return false -} - func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { - receivedOTKsForSelf := mach.receivedOTKsForSelf.Load() if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { - if otkCount.UserID != mach.Client.UserID || (!receivedOTKsForSelf && !mach.otkCountIsForCrossSigningKey(otkCount)) { - mach.Log.Warn(). - Str("target_user_id", otkCount.UserID.String()). - Str("target_device_id", otkCount.DeviceID.String()). - Msg("Dropping OTK counts targeted to someone else") - } + // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions + mach.Log.Warn(). + Str("target_user_id", otkCount.UserID.String()). + Str("target_device_id", otkCount.DeviceID.String()). + Msg("Dropping OTK counts targeted to someone else") return - } else if !receivedOTKsForSelf { - mach.receivedOTKsForSelf.Store(true) } minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2 @@ -310,7 +260,7 @@ func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.O log := mach.Log.With().Str("trace_id", traceID).Logger() ctx = log.WithContext(ctx) log.Debug(). - Int("keys_left", otkCount.SignedCurve25519). + Int("keys_left", otkCount.Curve25519). Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...") err := mach.ShareKeys(ctx, otkCount.SignedCurve25519) if err != nil { @@ -340,7 +290,6 @@ func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.R } mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount) - mach.MarkOlmHashSavePoint(ctx) return true } @@ -383,20 +332,20 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) Msg("Got membership state change, invalidating group session in room") err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) if err != nil { - mach.Log.Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session") + mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session") } } -func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) *DecryptedOlmEvent { +func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) { if _, ok := evt.Content.Parsed.(*event.EncryptedEventContent); !ok { mach.machOrContextLog(ctx).Warn().Msg("Passed invalid event to encrypted handler") - return nil + return } decryptedEvt, err := mach.decryptOlmEvent(ctx, evt) if err != nil { mach.machOrContextLog(ctx).Error().Err(err).Msg("Failed to decrypt to-device event") - return nil + return } log := mach.machOrContextLog(ctx).With(). @@ -425,37 +374,6 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve log.Trace().Msg("Handled secret send event") default: log.Debug().Msg("Unhandled encrypted to-device event") - return decryptedEvt - } - return nil -} - -const olmHashSavePointCount = 5 -const olmHashDeleteMinInterval = 10 * time.Minute -const minSavePointInterval = 1 * time.Minute - -// MarkOlmHashSavePoint marks the current time as a save point for olm hashes and deletes old hashes if needed. -// -// This should be called after all to-device events in a sync have been processed. -// The function will then delete old olm hashes after enough syncs have happened -// (such that it's unlikely for the olm messages to repeat). -func (mach *OlmMachine) MarkOlmHashSavePoint(ctx context.Context) { - mach.olmHashSavePointLock.Lock() - defer mach.olmHashSavePointLock.Unlock() - if len(mach.olmHashSavePoints) > 0 && time.Since(mach.olmHashSavePoints[len(mach.olmHashSavePoints)-1]) < minSavePointInterval { - return - } - mach.olmHashSavePoints = append(mach.olmHashSavePoints, time.Now()) - if len(mach.olmHashSavePoints) > olmHashSavePointCount { - sp := mach.olmHashSavePoints[0] - mach.olmHashSavePoints = mach.olmHashSavePoints[1:] - if time.Since(mach.lastHashDelete) > olmHashDeleteMinInterval { - err := mach.CryptoStore.DeleteOldOlmHashes(ctx, sp) - mach.lastHashDelete = time.Now() - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to delete old olm hashes") - } - } } } @@ -603,10 +521,10 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { - log.Err(err).Stringer("session_id", sessionID).Msg("Failed to store new inbound group session") + log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) + mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -617,7 +535,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen return nil } -func (mach *OlmMachine) MarkSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { +func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { if mach.SessionReceived != nil { mach.SessionReceived(ctx, roomID, id, firstKnownIndex) } @@ -730,7 +648,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro start := time.Now() mach.otkUploadLock.Lock() defer mach.otkUploadLock.Unlock() - if mach.lastOTKUpload.Add(1*time.Minute).After(start) || (currentOTKCount < 0 && mach.account.Shared) { + if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 { log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count") resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{}) if err != nil { diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go deleted file mode 100644 index fd40d795..00000000 --- a/crypto/machine_bench_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package crypto_test - -import ( - "context" - "fmt" - "math/rand/v2" - "testing" - - "github.com/rs/zerolog" - globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/crypto/cryptohelper" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/mockserver" -) - -func randomDeviceCount(r *rand.Rand) int { - k := 1 - for k < 10 && r.IntN(3) > 0 { - k++ - } - return k -} - -func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) { - globallog.Logger = zerolog.Nop() - server := mockserver.Create(b) - server.PopOTKs = false - server.MemoryStore = false - var i int - var shareTargets []id.UserID - r := rand.New(rand.NewPCG(293, 0)) - var totalDeviceCount int - for i = 1; i < 1000; i++ { - userID := id.UserID(fmt.Sprintf("@user%d:localhost", i)) - deviceCount := randomDeviceCount(r) - for j := 0; j < deviceCount; j++ { - client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j))) - mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() - keysCache, err := mach.GenerateCrossSigningKeys() - require.NoError(b, err) - err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) - require.NoError(b, err) - } - totalDeviceCount += deviceCount - shareTargets = append(shareTargets, userID) - } - for b.Loop() { - client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i))) - mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() - keysCache, err := mach.GenerateCrossSigningKeys() - require.NoError(b, err) - err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) - require.NoError(b, err) - err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets) - require.NoError(b, err) - i++ - } - fmt.Println(totalDeviceCount, "devices total") -} diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 872c3ac4..59c86236 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -36,15 +36,20 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, func newMachine(t *testing.T, userID id.UserID) *OlmMachine { client, err := mautrix.NewClient("http://localhost", userID, "token") - require.NoError(t, err, "Error creating client") + if err != nil { + t.Fatalf("Error creating client: %v", err) + } client.DeviceID = "device1" gobStore := NewMemoryStore(nil) - require.NoError(t, err, "Error creating Gob store") + if err != nil { + t.Fatalf("Error creating Gob store: %v", err) + } machine := NewOlmMachine(client, nil, gobStore, mockStateStore{}) - err = machine.Load(context.TODO()) - require.NoError(t, err, "Error creating account") + if err := machine.Load(context.TODO()); err != nil { + t.Fatalf("Error creating account: %v", err) + } return machine } @@ -77,7 +82,9 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // create outbound olm session for sending machine using OTK olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) - require.NoError(t, err, "Error creating outbound olm session") + if err != nil { + t.Errorf("Failed to create outbound olm session: %v", err) + } // store sender device identity in receiving machine store machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{ @@ -114,21 +121,29 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { Type: event.ToDeviceEncrypted, Sender: "user1", }, senderKey, content.Type, content.Body) - require.NoError(t, err, "Error decrypting olm ciphertext") - + if err != nil { + t.Errorf("Error decrypting olm content: %v", err) + } // store room key in new inbound group session roomKeyEvt := decrypted.Content.AsRoomKey() igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false) - require.NoError(t, err, "Error creating inbound group session") - err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs) - require.NoError(t, err, "Error storing inbound group session") + if err != nil { + t.Errorf("Error creating inbound megolm session: %v", err) + } + if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil { + t.Errorf("Error storing inbound megolm session: %v", err) + } } // encrypt event with megolm session in sending machine eventContent := map[string]string{"hello": "world"} encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - require.NoError(t, err, "Error encrypting megolm event") - assert.Equal(t, 1, megolmOutSession.MessageCount) + if err != nil { + t.Errorf("Error encrypting megolm event: %v", err) + } + if megolmOutSession.MessageCount != 1 { + t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount) + } encryptedEvt := &event.Event{ Content: event.Content{Parsed: encryptedEvtContent}, @@ -140,12 +155,22 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // decrypt event on receiving machine and confirm decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt) - require.NoError(t, err, "Error decrypting megolm event") - assert.Equal(t, event.EventMessage, decryptedEvt.Type) - assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"]) + if err != nil { + t.Errorf("Error decrypting megolm event: %v", err) + } + if decryptedEvt.Type != event.EventMessage { + t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type) + } + if decryptedEvt.Content.Raw["hello"] != "world" { + t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw) + } machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message") + if megolmOutSession.Expired() { + t.Error("Megolm outbound session expired before 3rd message") + } machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message") + if !megolmOutSession.Expired() { + t.Error("Megolm outbound session not expired after 3rd message") + } } diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 2ec5dd70..37458d1b 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -1,105 +1,28 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !goolm package olm +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "unsafe" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) -type Account interface { - // Pickle returns an Account as a base64 string. Encrypts the Account using the - // supplied key. - Pickle(key []byte) ([]byte, error) - - // Unpickle loads an Account from a pickled base64 string. Decrypts the - // Account using the supplied key. Returns error on failure. - Unpickle(pickled, key []byte) error - - // IdentityKeysJSON returns the public parts of the identity keys for the Account. - IdentityKeysJSON() ([]byte, error) - - // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity - // keys for the Account. - IdentityKeys() (id.Ed25519, id.Curve25519, error) - - // Sign returns the signature of a message using the ed25519 key for this - // Account. - Sign(message []byte) ([]byte, error) - - // OneTimeKeys returns the public parts of the unpublished one time keys for - // the Account. - // - // The returned data is a struct with the single value "Curve25519", which is - // itself an object mapping key id to base64-encoded Curve25519 key. For - // example: - // - // { - // Curve25519: { - // "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", - // "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" - // } - // } - OneTimeKeys() (map[string]id.Curve25519, error) - - // MarkKeysAsPublished marks the current set of one time keys as being - // published. - MarkKeysAsPublished() - - // MaxNumberOfOneTimeKeys returns the largest number of one time keys this - // Account can store. - MaxNumberOfOneTimeKeys() uint - - // GenOneTimeKeys generates a number of new one time keys. If the total - // number of keys stored by this Account exceeds MaxNumberOfOneTimeKeys - // then the old keys are discarded. - GenOneTimeKeys(num uint) error - - // NewOutboundSession creates a new out-bound session for sending messages to a - // given curve25519 identityKey and oneTimeKey. Returns error on failure. If the - // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" - NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (Session, error) - - // NewInboundSession creates a new in-bound session for sending/receiving - // messages from an incoming PRE_KEY message. Returns error on failure. If - // the base64 couldn't be decoded then the error will be "INVALID_BASE64". If - // the message was for an unsupported protocol version then the error will be - // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the - // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one - // time key then the error will be "BAD_MESSAGE_KEY_ID". - NewInboundSession(oneTimeKeyMsg string) (Session, error) - - // NewInboundSessionFrom creates a new in-bound session for sending/receiving - // messages from an incoming PRE_KEY message. Returns error on failure. If - // the base64 couldn't be decoded then the error will be "INVALID_BASE64". If - // the message was for an unsupported protocol version then the error will be - // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the - // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one - // time key then the error will be "BAD_MESSAGE_KEY_ID". - NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (Session, error) - - // RemoveOneTimeKeys removes the one time keys that the session used from the - // Account. Returns error on failure. If the Account doesn't have any - // matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". - RemoveOneTimeKeys(s Session) error -} - -var Driver = "none" - -var InitBlankAccount func() Account -var InitNewAccount func() (Account, error) -var InitNewAccountFromPickled func(pickled, key []byte) (Account, error) - -// NewAccount creates a new Account. -func NewAccount() (Account, error) { - return InitNewAccount() -} - -func NewBlankAccount() Account { - return InitBlankAccount() +// Account stores a device account for end to end encrypted messaging. +type Account struct { + int *C.OlmAccount + mem []byte } // AccountFromPickled loads an Account from a pickled base64 string. Decrypts @@ -107,6 +30,375 @@ func NewBlankAccount() Account { // doesn't match the one used to encrypt the Account then the error will be // "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be // "INVALID_BASE64". -func AccountFromPickled(pickled, key []byte) (Account, error) { - return InitNewAccountFromPickled(pickled, key) +func AccountFromPickled(pickled, key []byte) (*Account, error) { + if len(pickled) == 0 { + return nil, EmptyInput + } + a := NewBlankAccount() + return a, a.Unpickle(pickled, key) +} + +func NewBlankAccount() *Account { + memory := make([]byte, accountSize()) + return &Account{ + int: C.olm_account(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// NewAccount creates a new Account. +func NewAccount() *Account { + a := NewBlankAccount() + random := make([]byte, a.createRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(NotEnoughGoRandom) + } + r := C.olm_create_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&random[0]), + C.size_t(len(random))) + if r == errorVal() { + panic(a.lastError()) + } else { + return a + } +} + +// accountSize returns the size of an account object in bytes. +func accountSize() uint { + return uint(C.olm_account_size()) +} + +// lastError returns an error describing the most recent error to happen to an +// account. +func (a *Account) lastError() error { + return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))) +} + +// Clear clears the memory used to back this Account. +func (a *Account) Clear() error { + r := C.olm_clear_account((*C.OlmAccount)(a.int)) + if r == errorVal() { + return a.lastError() + } else { + return nil + } +} + +// pickleLen returns the number of bytes needed to store an Account. +func (a *Account) pickleLen() uint { + return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int))) +} + +// createRandomLen returns the number of random bytes needed to create an +// Account. +func (a *Account) createRandomLen() uint { + return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int))) +} + +// identityKeysLen returns the size of the output buffer needed to hold the +// identity keys. +func (a *Account) identityKeysLen() uint { + return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))) +} + +// signatureLen returns the length of an ed25519 signature encoded as base64. +func (a *Account) signatureLen() uint { + return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int))) +} + +// oneTimeKeysLen returns the size of the output buffer needed to hold the one +// time keys. +func (a *Account) oneTimeKeysLen() uint { + return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))) +} + +// genOneTimeKeysRandomLen returns the number of random bytes needed to +// generate a given number of new one time keys. +func (a *Account) genOneTimeKeysRandomLen(num uint) uint { + return uint(C.olm_account_generate_one_time_keys_random_length( + (*C.OlmAccount)(a.int), + C.size_t(num))) +} + +// Pickle returns an Account as a base64 string. Encrypts the Account using the +// supplied key. +func (a *Account) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled := make([]byte, a.pickleLen()) + r := C.olm_pickle_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + panic(a.lastError()) + } + return pickled[:r] +} + +func (a *Account) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return NoKeyProvided + } + r := C.olm_unpickle_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return a.lastError() + } + return nil +} + +// Deprecated +func (a *Account) GobEncode() ([]byte, error) { + pickled := a.Pickle(pickleKey) + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (a *Account) GobDecode(rawPickled []byte) error { + if a.int == nil { + *a = *NewBlankAccount() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return a.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (a *Account) MarshalJSON() ([]byte, error) { + pickled := a.Pickle(pickleKey) + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (a *Account) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return InputNotJSONString + } + if a.int == nil { + *a = *NewBlankAccount() + } + return a.Unpickle(data[1:len(data)-1], pickleKey) +} + +// IdentityKeysJSON returns the public parts of the identity keys for the Account. +func (a *Account) IdentityKeysJSON() []byte { + identityKeys := make([]byte, a.identityKeysLen()) + r := C.olm_account_identity_keys( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&identityKeys[0]), + C.size_t(len(identityKeys))) + if r == errorVal() { + panic(a.lastError()) + } else { + return identityKeys + } +} + +// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity +// keys for the Account. +func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) { + identityKeysJSON := a.IdentityKeysJSON() + results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519") + return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str) +} + +// Sign returns the signature of a message using the ed25519 key for this +// Account. +func (a *Account) Sign(message []byte) []byte { + if len(message) == 0 { + panic(EmptyInput) + } + signature := make([]byte, a.signatureLen()) + r := C.olm_account_sign( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&message[0]), + C.size_t(len(message)), + unsafe.Pointer(&signature[0]), + C.size_t(len(signature))) + if r == errorVal() { + panic(a.lastError()) + } + return signature +} + +// SignJSON signs the given JSON object following the Matrix specification: +// https://matrix.org/docs/spec/appendices#signing-json +func (a *Account) SignJSON(obj interface{}) (string, error) { + objJSON, err := json.Marshal(obj) + if err != nil { + return "", err + } + objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") + objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") + return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil +} + +// OneTimeKeys returns the public parts of the unpublished one time keys for +// the Account. +// +// The returned data is a struct with the single value "Curve25519", which is +// itself an object mapping key id to base64-encoded Curve25519 key. For +// example: +// +// { +// Curve25519: { +// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", +// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" +// } +// } +func (a *Account) OneTimeKeys() map[string]id.Curve25519 { + oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) + r := C.olm_account_one_time_keys( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&oneTimeKeysJSON[0]), + C.size_t(len(oneTimeKeysJSON))) + if r == errorVal() { + panic(a.lastError()) + } + var oneTimeKeys struct { + Curve25519 map[string]id.Curve25519 `json:"curve25519"` + } + err := json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys) + if err != nil { + panic(err) + } + return oneTimeKeys.Curve25519 +} + +// MarkKeysAsPublished marks the current set of one time keys as being +// published. +func (a *Account) MarkKeysAsPublished() { + C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int)) +} + +// MaxNumberOfOneTimeKeys returns the largest number of one time keys this +// Account can store. +func (a *Account) MaxNumberOfOneTimeKeys() uint { + return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))) +} + +// GenOneTimeKeys generates a number of new one time keys. If the total number +// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old +// keys are discarded. +func (a *Account) GenOneTimeKeys(num uint) { + random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) + _, err := rand.Read(random) + if err != nil { + panic(NotEnoughGoRandom) + } + r := C.olm_account_generate_one_time_keys( + (*C.OlmAccount)(a.int), + C.size_t(num), + unsafe.Pointer(&random[0]), + C.size_t(len(random))) + if r == errorVal() { + panic(a.lastError()) + } +} + +// NewOutboundSession creates a new out-bound session for sending messages to a +// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the +// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" +func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { + if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { + return nil, EmptyInput + } + s := NewBlankSession() + random := make([]byte, s.createOutboundRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(NotEnoughGoRandom) + } + r := C.olm_create_outbound_session( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(&([]byte(theirIdentityKey)[0])), + C.size_t(len(theirIdentityKey)), + unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), + C.size_t(len(theirOneTimeKey)), + unsafe.Pointer(&random[0]), + C.size_t(len(random))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// NewInboundSession creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. If +// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If +// the message was for an unsupported protocol version then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the +// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one +// time key then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { + if len(oneTimeKeyMsg) == 0 { + return nil, EmptyInput + } + s := NewBlankSession() + r := C.olm_create_inbound_session( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), + C.size_t(len(oneTimeKeyMsg))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// NewInboundSessionFrom creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. If +// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If +// the message was for an unsupported protocol version then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the +// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one +// time key then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { + if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { + return nil, EmptyInput + } + s := NewBlankSession() + r := C.olm_create_inbound_session_from( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(&([]byte(theirIdentityKey)[0])), + C.size_t(len(theirIdentityKey)), + unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), + C.size_t(len(oneTimeKeyMsg))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// RemoveOneTimeKeys removes the one time keys that the session used from the +// Account. Returns error on failure. If the Account doesn't have any +// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) RemoveOneTimeKeys(s *Session) error { + r := C.olm_remove_one_time_keys( + (*C.OlmAccount)(a.int), + (*C.OlmSession)(s.int)) + if r == errorVal() { + return a.lastError() + } + return nil } diff --git a/crypto/olm/account_goolm.go b/crypto/olm/account_goolm.go new file mode 100644 index 00000000..eeff54f9 --- /dev/null +++ b/crypto/olm/account_goolm.go @@ -0,0 +1,154 @@ +//go:build goolm + +package olm + +import ( + "encoding/json" + + "github.com/tidwall/sjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/id" +) + +// Account stores a device account for end to end encrypted messaging. +type Account struct { + account.Account +} + +// NewAccount creates a new Account. +func NewAccount() *Account { + a, err := account.NewAccount(nil) + if err != nil { + panic(err) + } + ac := &Account{} + ac.Account = *a + return ac +} + +func NewBlankAccount() *Account { + return &Account{} +} + +// Clear clears the memory used to back this Account. +func (a *Account) Clear() error { + a.Account = account.Account{} + return nil +} + +// Pickle returns an Account as a base64 string. Encrypts the Account using the +// supplied key. +func (a *Account) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled, err := a.Account.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +// IdentityKeysJSON returns the public parts of the identity keys for the Account. +func (a *Account) IdentityKeysJSON() []byte { + identityKeys, err := a.Account.IdentityKeysJSON() + if err != nil { + panic(err) + } + return identityKeys +} + +// Sign returns the signature of a message using the ed25519 key for this +// Account. +func (a *Account) Sign(message []byte) []byte { + if len(message) == 0 { + panic(EmptyInput) + } + signature, err := a.Account.Sign(message) + if err != nil { + panic(err) + } + return signature +} + +// SignJSON signs the given JSON object following the Matrix specification: +// https://matrix.org/docs/spec/appendices#signing-json +func (a *Account) SignJSON(obj interface{}) (string, error) { + objJSON, err := json.Marshal(obj) + if err != nil { + return "", err + } + objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") + objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") + return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil +} + +// MaxNumberOfOneTimeKeys returns the largest number of one time keys this +// Account can store. +func (a *Account) MaxNumberOfOneTimeKeys() uint { + return uint(account.MaxOneTimeKeys) +} + +// GenOneTimeKeys generates a number of new one time keys. If the total number +// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old +// keys are discarded. +func (a *Account) GenOneTimeKeys(num uint) { + err := a.Account.GenOneTimeKeys(nil, num) + if err != nil { + panic(err) + } +} + +// NewOutboundSession creates a new out-bound session for sending messages to a +// given curve25519 identityKey and oneTimeKey. Returns error on failure. +func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { + if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { + return nil, EmptyInput + } + s := &Session{} + newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey) + if err != nil { + return nil, err + } + s.OlmSession = *newSession + return s, nil +} + +// NewInboundSession creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. +func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { + if len(oneTimeKeyMsg) == 0 { + return nil, EmptyInput + } + s := &Session{} + newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg)) + if err != nil { + return nil, err + } + s.OlmSession = *newSession + return s, nil +} + +// NewInboundSessionFrom creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. +func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { + if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { + return nil, EmptyInput + } + s := &Session{} + newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg)) + if err != nil { + return nil, err + } + s.OlmSession = *newSession + return s, nil +} + +// RemoveOneTimeKeys removes the one time keys that the session used from the +// Account. Returns error on failure. +func (a *Account) RemoveOneTimeKeys(s *Session) error { + a.Account.RemoveOneTimeKeys(&s.OlmSession) + return nil +} diff --git a/crypto/olm/account_test.go b/crypto/olm/account_test.go deleted file mode 100644 index 0e055881..00000000 --- a/crypto/olm/account_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package olm_test - -import ( - "bytes" - "encoding/base64" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/crypto/ed25519" - "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/crypto/libolm" - "maunium.net/go/mautrix/crypto/olm" -) - -func ensureAccountsEqual(t *testing.T, a, b olm.Account) { - t.Helper() - - assert.Equal(t, a.MaxNumberOfOneTimeKeys(), b.MaxNumberOfOneTimeKeys()) - - aEd25519, aCurve25519, err := a.IdentityKeys() - require.NoError(t, err) - bEd25519, bCurve25519, err := b.IdentityKeys() - require.NoError(t, err) - assert.Equal(t, aEd25519, bEd25519) - assert.Equal(t, aCurve25519, bCurve25519) - - aIdentityKeysJSON, err := a.IdentityKeysJSON() - require.NoError(t, err) - bIdentityKeysJSON, err := b.IdentityKeysJSON() - require.NoError(t, err) - assert.JSONEq(t, string(aIdentityKeysJSON), string(bIdentityKeysJSON)) - - aOTKs, err := a.OneTimeKeys() - require.NoError(t, err) - bOTKs, err := b.OneTimeKeys() - require.NoError(t, err) - assert.Equal(t, aOTKs, bOTKs) -} - -// TestAccount_UnpickleLibolmToGoolm tests creating an account from libolm, -// pickling it, and importing it into goolm. -func TestAccount_UnpickleLibolmToGoolm(t *testing.T) { - libolmAccount, err := libolm.NewAccount() - require.NoError(t, err) - - require.NoError(t, libolmAccount.GenOneTimeKeys(50)) - - libolmPickled, err := libolmAccount.Pickle([]byte("test")) - require.NoError(t, err) - - goolmAccount, err := account.AccountFromPickled(libolmPickled, []byte("test")) - require.NoError(t, err) - - ensureAccountsEqual(t, libolmAccount, goolmAccount) - - goolmPickled, err := goolmAccount.Pickle([]byte("test")) - require.NoError(t, err) - assert.Equal(t, libolmPickled, goolmPickled) -} - -// TestAccount_UnpickleGoolmToLibolm tests creating an account from goolm, -// pickling it, and importing it into libolm. -func TestAccount_UnpickleGoolmToLibolm(t *testing.T) { - goolmAccount, err := account.NewAccount() - require.NoError(t, err) - - require.NoError(t, goolmAccount.GenOneTimeKeys(50)) - - goolmPickled, err := goolmAccount.Pickle([]byte("test")) - require.NoError(t, err) - - libolmAccount, err := libolm.AccountFromPickled(bytes.Clone(goolmPickled), []byte("test")) - require.NoError(t, err) - - ensureAccountsEqual(t, libolmAccount, goolmAccount) - - libolmPickled, err := libolmAccount.Pickle([]byte("test")) - require.NoError(t, err) - assert.Equal(t, goolmPickled, libolmPickled) -} - -func FuzzAccount_Sign(f *testing.F) { - f.Add([]byte("anything")) - - libolmAccount := exerrors.Must(libolm.NewAccount()) - goolmAccount := exerrors.Must(account.AccountFromPickled(exerrors.Must(libolmAccount.Pickle([]byte("test"))), []byte("test"))) - - f.Fuzz(func(t *testing.T, message []byte) { - if len(message) == 0 { - t.Skip("empty message is not supported") - } - - libolmSignature, err := libolmAccount.Sign(bytes.Clone(message)) - require.NoError(t, err) - goolmSignature, err := goolmAccount.Sign(bytes.Clone(message)) - require.NoError(t, err) - assert.Equal(t, goolmSignature, libolmSignature) - - goolmSignatureBytes, err := base64.RawStdEncoding.DecodeString(string(goolmSignature)) - require.NoError(t, err) - libolmSignatureBytes, err := base64.RawStdEncoding.DecodeString(string(libolmSignature)) - require.NoError(t, err) - - libolmEd25519, _, err := libolmAccount.IdentityKeys() - require.NoError(t, err) - - assert.True(t, ed25519.Verify(ed25519.PublicKey(libolmEd25519.Bytes()), message, libolmSignatureBytes)) - assert.True(t, ed25519.Verify(ed25519.PublicKey(libolmEd25519.Bytes()), message, goolmSignatureBytes)) - - assert.True(t, goolmAccount.IdKeys.Ed25519.Verify(bytes.Clone(message), libolmSignatureBytes)) - assert.True(t, goolmAccount.IdKeys.Ed25519.Verify(bytes.Clone(message), goolmSignatureBytes)) - }) -} diff --git a/crypto/olm/error.go b/crypto/olm/error.go new file mode 100644 index 00000000..63352e20 --- /dev/null +++ b/crypto/olm/error.go @@ -0,0 +1,62 @@ +//go:build !goolm + +package olm + +import ( + "errors" + "fmt" +) + +// Error codes from go-olm +var ( + EmptyInput = errors.New("empty input") + NoKeyProvided = errors.New("no pickle key provided") + NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") + InputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Error codes from olm code +var ( + NotEnoughRandom = errors.New("not enough entropy was supplied") + OutputBufferTooSmall = errors.New("supplied output buffer is too small") + BadMessageVersion = errors.New("the message version is unsupported") + BadMessageFormat = errors.New("the message couldn't be decoded") + BadMessageMAC = errors.New("the message couldn't be decrypted") + BadMessageKeyID = errors.New("the message references an unknown key ID") + InvalidBase64 = errors.New("the input base64 was invalid") + BadAccountKey = errors.New("the supplied account key is invalid") + UnknownPickleVersion = errors.New("the pickled object is too new") + CorruptedPickle = errors.New("the pickled object couldn't be decoded") + BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") + UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") + BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") + BadSignature = errors.New("received message had a bad signature") + InputBufferTooSmall = errors.New("the input data was too small to be valid") +) + +var errorMap = map[string]error{ + "NOT_ENOUGH_RANDOM": NotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": OutputBufferTooSmall, + "BAD_MESSAGE_VERSION": BadMessageVersion, + "BAD_MESSAGE_FORMAT": BadMessageFormat, + "BAD_MESSAGE_MAC": BadMessageMAC, + "BAD_MESSAGE_KEY_ID": BadMessageKeyID, + "INVALID_BASE64": InvalidBase64, + "BAD_ACCOUNT_KEY": BadAccountKey, + "UNKNOWN_PICKLE_VERSION": UnknownPickleVersion, + "CORRUPTED_PICKLE": CorruptedPickle, + "BAD_SESSION_KEY": BadSessionKey, + "UNKNOWN_MESSAGE_INDEX": UnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": BadLegacyAccountPickle, + "BAD_SIGNATURE": BadSignature, + "INPUT_BUFFER_TOO_SMALL": InputBufferTooSmall, +} + +func convertError(errCode string) error { + err, ok := errorMap[errCode] + if ok { + return err + } + return fmt.Errorf("unknown error: %s", errCode) +} diff --git a/crypto/olm/error_goolm.go b/crypto/olm/error_goolm.go new file mode 100644 index 00000000..0e54e566 --- /dev/null +++ b/crypto/olm/error_goolm.go @@ -0,0 +1,23 @@ +//go:build goolm + +package olm + +import ( + "errors" + + "maunium.net/go/mautrix/crypto/goolm" +) + +// Error codes from go-olm +var ( + EmptyInput = goolm.ErrEmptyInput + NoKeyProvided = goolm.ErrNoKeyProvided + NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") + InputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Error codes from olm code +var ( + UnknownMessageIndex = goolm.ErrRatchetNotAvailable +) diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go deleted file mode 100644 index 9e522b2a..00000000 --- a/crypto/olm/errors.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package olm - -import "errors" - -// Those are the most common used errors -var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)") - ErrBadMessageFormat = errors.New("the message couldn't be decoded") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key provided") - ErrBadMessageKeyID = errors.New("the message references an unknown key ID") - ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version") - ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version") - ErrInputToSmall = errors.New("input too small (truncated?)") -) - -// Error codes from go-olm -var ( - ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") -) - -// Error codes from olm code -var ( - ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid") - - ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied") - ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small") - ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid") - ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded") - ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") - ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") -) - -// Deprecated: use variables prefixed with Err -var ( - EmptyInput = ErrEmptyInput - BadSignature = ErrBadSignature - InvalidBase64 = ErrLibolmInvalidBase64 - BadMessageKeyID = ErrBadMessageKeyID - BadMessageFormat = ErrBadMessageFormat - BadMessageVersion = ErrWrongProtocolVersion - BadMessageMAC = ErrBadMAC - UnknownPickleVersion = ErrUnknownOlmPickleVersion - NotEnoughRandom = ErrLibolmNotEnoughRandom - OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall - BadAccountKey = ErrLibolmBadAccountKey - CorruptedPickle = ErrLibolmCorruptedPickle - BadSessionKey = ErrLibolmBadSessionKey - UnknownMessageIndex = ErrUnknownMessageIndex - BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle - InputBufferTooSmall = ErrInputToSmall - NoKeyProvided = ErrNoKeyProvided - - NotEnoughGoRandom = ErrNotEnoughGoRandom - InputNotJSONString = ErrInputNotJSONString - - ErrBadVersion = ErrUnknownJSONPickleVersion - ErrWrongPickleVersion = ErrUnknownJSONPickleVersion - ErrRatchetNotAvailable = ErrUnknownMessageIndex -) diff --git a/crypto/olm/groupsession_test.go b/crypto/olm/groupsession_test.go deleted file mode 100644 index 0f845e90..00000000 --- a/crypto/olm/groupsession_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package olm_test - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/libolm" -) - -// TestEncryptDecrypt_GoolmToLibolm tests encryption where goolm encrypts and libolm decrypts -func TestEncryptDecrypt_GoolmToLibolm(t *testing.T) { - goolmOutbound, err := session.NewMegolmOutboundSession() - require.NoError(t, err) - - libolmInbound, err := libolm.NewInboundGroupSession([]byte(goolmOutbound.Key())) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - ciphertext, err := goolmOutbound.Encrypt([]byte(fmt.Sprintf("message %d", i))) - require.NoError(t, err) - - plaintext, msgIdx, err := libolmInbound.Decrypt(ciphertext) - assert.NoError(t, err) - assert.Equal(t, []byte(fmt.Sprintf("message %d", i)), plaintext) - assert.Equal(t, goolmOutbound.MessageIndex()-1, msgIdx) - } -} - -func TestEncryptDecrypt_LibolmToGoolm(t *testing.T) { - libolmOutbound, err := libolm.NewOutboundGroupSession() - require.NoError(t, err) - goolmInbound, err := session.NewMegolmInboundSession([]byte(libolmOutbound.Key())) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - ciphertext, err := libolmOutbound.Encrypt([]byte(fmt.Sprintf("message %d", i))) - require.NoError(t, err) - - plaintext, msgIdx, err := goolmInbound.Decrypt(ciphertext) - assert.NoError(t, err) - assert.Equal(t, []byte(fmt.Sprintf("message %d", i)), plaintext) - assert.Equal(t, libolmOutbound.MessageIndex()-1, msgIdx) - } -} diff --git a/crypto/olm/inboundgroupsession.go b/crypto/olm/inboundgroupsession.go index 8839b48c..cac49d18 100644 --- a/crypto/olm/inboundgroupsession.go +++ b/crypto/olm/inboundgroupsession.go @@ -1,80 +1,305 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !goolm package olm -import "maunium.net/go/mautrix/id" +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" -type InboundGroupSession interface { - // Pickle returns an InboundGroupSession as a base64 string. Encrypts the - // InboundGroupSession using the supplied key. - Pickle(key []byte) ([]byte, error) +import ( + "bytes" + "encoding/base64" + "unsafe" - // Unpickle loads an [InboundGroupSession] from a pickled base64 string. - // Decrypts the [InboundGroupSession] using the supplied key. - Unpickle(pickled, key []byte) error + "maunium.net/go/mautrix/id" +) - // Decrypt decrypts a message using the [InboundGroupSession]. Returns the - // plain-text and message index on success. Returns error on failure. If - // the base64 couldn't be decoded then the error will be "INVALID_BASE64". - // If the message is for an unsupported version of the protocol then the - // error will be "BAD_MESSAGE_VERSION". If the message couldn't be decoded - // then the error will be BAD_MESSAGE_FORMAT". If the MAC on the message - // was invalid then the error will be "BAD_MESSAGE_MAC". If we do not have - // a session key corresponding to the message's index (ie, it was sent - // before the session key was shared with us) the error will be - // "OLM_UNKNOWN_MESSAGE_INDEX". - Decrypt(message []byte) ([]byte, uint, error) - - // ID returns a base64-encoded identifier for this session. - ID() id.SessionID - - // FirstKnownIndex returns the first message index we know how to decrypt. - FirstKnownIndex() uint32 - - // IsVerified check if the session has been verified as a valid session. - // (A session is verified either because the original session share was - // signed, or because we have subsequently successfully decrypted a - // message.) - IsVerified() bool - - // Export returns the base64-encoded ratchet key for this session, at the - // given index, in a format which can be used by - // InboundGroupSession.InboundGroupSessionImport(). Encrypts the - // InboundGroupSession using the supplied key. Returns error on failure. - // if we do not have a session key corresponding to the given index (ie, it - // was sent before the session key was shared with us) the error will be - // "OLM_UNKNOWN_MESSAGE_INDEX". - Export(messageIndex uint32) ([]byte, error) +// InboundGroupSession stores an inbound encrypted messaging session for a +// group. +type InboundGroupSession struct { + int *C.OlmInboundGroupSession + mem []byte } -var InitInboundGroupSessionFromPickled func(pickled, key []byte) (InboundGroupSession, error) -var InitNewInboundGroupSession func(sessionKey []byte) (InboundGroupSession, error) -var InitInboundGroupSessionImport func(sessionKey []byte) (InboundGroupSession, error) -var InitBlankInboundGroupSession func() InboundGroupSession - // InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled -// base64 string. Decrypts the InboundGroupSession using the supplied key. -// Returns error on failure. -func InboundGroupSessionFromPickled(pickled, key []byte) (InboundGroupSession, error) { - return InitInboundGroupSessionFromPickled(pickled, key) +// base64 string. Decrypts the InboundGroupSession using the supplied key. +// Returns error on failure. If the key doesn't match the one used to encrypt +// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". +func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { + if len(pickled) == 0 { + return nil, EmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + s := NewBlankInboundGroupSession() + return s, s.Unpickle(pickled, key) } // NewInboundGroupSession creates a new inbound group session from a key -// exported from OutboundGroupSession.Key(). Returns error on failure. -func NewInboundGroupSession(sessionKey []byte) (InboundGroupSession, error) { - return InitNewInboundGroupSession(sessionKey) +// exported from OutboundGroupSession.Key(). Returns error on failure. +// If the sessionKey is not valid base64 the error will be +// "OLM_INVALID_BASE64". If the session_key is invalid the error will be +// "OLM_BAD_SESSION_KEY". +func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, EmptyInput + } + s := NewBlankInboundGroupSession() + r := C.olm_init_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil } // InboundGroupSessionImport imports an inbound group session from a previous -// export. Returns error on failure. -func InboundGroupSessionImport(sessionKey []byte) (InboundGroupSession, error) { - return InitInboundGroupSessionImport(sessionKey) +// export. Returns error on failure. If the sessionKey is not valid base64 +// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the +// error will be "OLM_BAD_SESSION_KEY". +func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, EmptyInput + } + s := NewBlankInboundGroupSession() + r := C.olm_import_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil } -func NewBlankInboundGroupSession() InboundGroupSession { - return InitBlankInboundGroupSession() +// inboundGroupSessionSize is the size of an inbound group session object in +// bytes. +func inboundGroupSessionSize() uint { + return uint(C.olm_inbound_group_session_size()) +} + +// newInboundGroupSession initialises an empty InboundGroupSession. +func NewBlankInboundGroupSession() *InboundGroupSession { + memory := make([]byte, inboundGroupSessionSize()) + return &InboundGroupSession{ + int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to an +// inbound group session. +func (s *InboundGroupSession) lastError() error { + return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int)))) +} + +// Clear clears the memory used to back this InboundGroupSession. +func (s *InboundGroupSession) Clear() error { + r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int)) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// pickleLen returns the number of bytes needed to store an inbound group +// session. +func (s *InboundGroupSession) pickleLen() uint { + return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) +} + +// Pickle returns an InboundGroupSession as a base64 string. Encrypts the +// InboundGroupSession using the supplied key. +func (s *InboundGroupSession) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + panic(s.lastError()) + } + return pickled[:r] +} + +func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return NoKeyProvided + } else if len(pickled) == 0 { + return EmptyInput + } + r := C.olm_unpickle_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *InboundGroupSession) GobEncode() ([]byte, error) { + pickled := s.Pickle(pickleKey) + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankInboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { + pickled := s.Pickle(pickleKey) + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return InputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankInboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a +// given message could decode to. The actual size could be different due to +// padding. Returns error on failure. If the message base64 couldn't be +// decoded then the error will be "INVALID_BASE64". If the message is for an +// unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be "BAD_MESSAGE_FORMAT". +func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { + if len(message) == 0 { + return 0, EmptyInput + } + // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it + message = bytes.Clone(message) + r := C.olm_group_decrypt_max_plaintext_length( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&message[0]), + C.size_t(len(message))) + if r == errorVal() { + return 0, s.lastError() + } + return uint(r), nil +} + +// Decrypt decrypts a message using the InboundGroupSession. Returns the the +// plain-text and message index on success. Returns error on failure. If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". If the +// message is for an unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then the +// error will be "BAD_MESSAGE_MAC". If we do not have a session key +// corresponding to the message's index (ie, it was sent before the session key +// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". +func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { + if len(message) == 0 { + return nil, 0, EmptyInput + } + decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) + if err != nil { + return nil, 0, err + } + messageCopy := make([]byte, len(message)) + copy(messageCopy, message) + plaintext := make([]byte, decryptMaxPlaintextLen) + var messageIndex uint32 + r := C.olm_group_decrypt( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&messageCopy[0]), + C.size_t(len(messageCopy)), + (*C.uint8_t)(&plaintext[0]), + C.size_t(len(plaintext)), + (*C.uint32_t)(&messageIndex)) + if r == errorVal() { + return nil, 0, s.lastError() + } + return plaintext[:r], uint(messageIndex), nil +} + +// sessionIdLen returns the number of bytes needed to store a session ID. +func (s *InboundGroupSession) sessionIdLen() uint { + return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int))) +} + +// ID returns a base64-encoded identifier for this session. +func (s *InboundGroupSession) ID() id.SessionID { + sessionID := make([]byte, s.sessionIdLen()) + r := C.olm_inbound_group_session_id( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&sessionID[0]), + C.size_t(len(sessionID))) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID[:r]) +} + +// FirstKnownIndex returns the first message index we know how to decrypt. +func (s *InboundGroupSession) FirstKnownIndex() uint32 { + return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int))) +} + +// IsVerified check if the session has been verified as a valid session. (A +// session is verified either because the original session share was signed, or +// because we have subsequently successfully decrypted a message.) +func (s *InboundGroupSession) IsVerified() uint { + return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) +} + +// exportLen returns the number of bytes needed to export an inbound group +// session. +func (s *InboundGroupSession) exportLen() uint { + return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) +} + +// Export returns the base64-encoded ratchet key for this session, at the given +// index, in a format which can be used by +// InboundGroupSession.InboundGroupSessionImport(). Encrypts the +// InboundGroupSession using the supplied key. Returns error on failure. +// if we do not have a session key corresponding to the given index (ie, it was +// sent before the session key was shared with us) the error will be +// "OLM_UNKNOWN_MESSAGE_INDEX". +func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { + key := make([]byte, s.exportLen()) + r := C.olm_export_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&key[0]), + C.size_t(len(key)), + C.uint32_t(messageIndex)) + if r == errorVal() { + return nil, s.lastError() + } + return key[:r], nil } diff --git a/crypto/olm/inboundgroupsession_goolm.go b/crypto/olm/inboundgroupsession_goolm.go new file mode 100644 index 00000000..4e561cf7 --- /dev/null +++ b/crypto/olm/inboundgroupsession_goolm.go @@ -0,0 +1,149 @@ +//go:build goolm + +package olm + +import ( + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/id" +) + +// InboundGroupSession stores an inbound encrypted messaging session for a +// group. +type InboundGroupSession struct { + session.MegolmInboundSession +} + +// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled +// base64 string. Decrypts the InboundGroupSession using the supplied key. +// Returns error on failure. +func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { + if len(pickled) == 0 { + return nil, EmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key) + if err != nil { + return nil, err + } + return &InboundGroupSession{ + MegolmInboundSession: *megolmSession, + }, nil +} + +// NewInboundGroupSession creates a new inbound group session from a key +// exported from OutboundGroupSession.Key(). Returns error on failure. +func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, EmptyInput + } + megolmSession, err := session.NewMegolmInboundSession(sessionKey) + if err != nil { + return nil, err + } + return &InboundGroupSession{ + MegolmInboundSession: *megolmSession, + }, nil +} + +// InboundGroupSessionImport imports an inbound group session from a previous +// export. Returns error on failure. +func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, EmptyInput + } + megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey) + if err != nil { + return nil, err + } + return &InboundGroupSession{ + MegolmInboundSession: *megolmSession, + }, nil +} + +func NewBlankInboundGroupSession() *InboundGroupSession { + return &InboundGroupSession{} +} + +// Clear clears the memory used to back this InboundGroupSession. +func (s *InboundGroupSession) Clear() error { + s.MegolmInboundSession = session.MegolmInboundSession{} + return nil +} + +// Pickle returns an InboundGroupSession as a base64 string. Encrypts the +// InboundGroupSession using the supplied key. +func (s *InboundGroupSession) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled, err := s.MegolmInboundSession.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return NoKeyProvided + } else if len(pickled) == 0 { + return EmptyInput + } + sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key) + if err != nil { + return err + } + s.MegolmInboundSession = *sOlm + return nil +} + +// Decrypt decrypts a message using the InboundGroupSession. Returns the the +// plain-text and message index on success. Returns error on failure. +func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { + if len(message) == 0 { + return nil, 0, EmptyInput + } + plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message) + if err != nil { + return nil, 0, err + } + return plaintext, uint(messageIndex), nil +} + +// ID returns a base64-encoded identifier for this session. +func (s *InboundGroupSession) ID() id.SessionID { + return s.MegolmInboundSession.SessionID() +} + +// FirstKnownIndex returns the first message index we know how to decrypt. +func (s *InboundGroupSession) FirstKnownIndex() uint32 { + return s.MegolmInboundSession.InitialRatchet.Counter +} + +// IsVerified check if the session has been verified as a valid session. (A +// session is verified either because the original session share was signed, or +// because we have subsequently successfully decrypted a message.) +func (s *InboundGroupSession) IsVerified() uint { + if s.MegolmInboundSession.SigningKeyVerified { + return 1 + } + return 0 +} + +// Export returns the base64-encoded ratchet key for this session, at the given +// index, in a format which can be used by +// InboundGroupSession.InboundGroupSessionImport(). Encrypts the +// InboundGroupSession using the supplied key. Returns error on failure. +// if we do not have a session key corresponding to the given index (ie, it was +// sent before the session key was shared with us) the error will be +// returned. +func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { + res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex) + if err != nil { + return nil, err + } + return res, nil +} diff --git a/crypto/olm/olm.go b/crypto/olm/olm.go index fa2345e1..fa1ae856 100644 --- a/crypto/olm/olm.go +++ b/crypto/olm/olm.go @@ -1,20 +1,28 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !goolm package olm -var GetVersion func() (major, minor, patch uint8) -var SetPickleKeyImpl func(key []byte) +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" // Version returns the version number of the olm library. func Version() (major, minor, patch uint8) { - return GetVersion() + C.olm_get_library_version( + (*C.uint8_t)(&major), + (*C.uint8_t)(&minor), + (*C.uint8_t)(&patch)) + return } +// errorVal returns the value that olm functions return if there was an error. +func errorVal() C.size_t { + return C.olm_error() +} + +var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") + // SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. func SetPickleKey(key []byte) { - SetPickleKeyImpl(key) + pickleKey = key } diff --git a/crypto/olm/olm_goolm.go b/crypto/olm/olm_goolm.go new file mode 100644 index 00000000..a1489ded --- /dev/null +++ b/crypto/olm/olm_goolm.go @@ -0,0 +1,13 @@ +//go:build goolm + +package olm + +// Version returns the version number of the olm library. +func Version() (major, minor, patch uint8) { + return 3, 2, 15 +} + +// SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. +func SetPickleKey(key []byte) { + panic("gob and json encoding is deprecated and not supported with goolm") +} diff --git a/crypto/olm/outboundgroupsession.go b/crypto/olm/outboundgroupsession.go index 7e582b7e..b6a33d36 100644 --- a/crypto/olm/outboundgroupsession.go +++ b/crypto/olm/outboundgroupsession.go @@ -1,57 +1,239 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !goolm package olm -import "maunium.net/go/mautrix/id" +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" -type OutboundGroupSession interface { - // Pickle returns a Session as a base64 string. Encrypts the Session using - // the supplied key. - Pickle(key []byte) ([]byte, error) +import ( + "crypto/rand" + "encoding/base64" + "unsafe" - // Unpickle loads an [OutboundGroupSession] from a pickled base64 string. - // Decrypts the [OutboundGroupSession] using the supplied key. - Unpickle(pickled, key []byte) error + "maunium.net/go/mautrix/id" +) - // Encrypt encrypts a message using the [OutboundGroupSession]. Returns the - // encrypted message as base64. - Encrypt(plaintext []byte) ([]byte, error) - - // ID returns a base64-encoded identifier for this session. - ID() id.SessionID - - // MessageIndex returns the message index for this session. Each message - // is sent with an increasing index; this returns the index for the next - // message. - MessageIndex() uint - - // Key returns the base64-encoded current ratchet key for this session. - Key() string +// OutboundGroupSession stores an outbound encrypted messaging session for a +// group. +type OutboundGroupSession struct { + int *C.OlmOutboundGroupSession + mem []byte } -var InitNewOutboundGroupSessionFromPickled func(pickled, key []byte) (OutboundGroupSession, error) -var InitNewOutboundGroupSession func() (OutboundGroupSession, error) -var InitNewBlankOutboundGroupSession func() OutboundGroupSession - // OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled // base64 string. Decrypts the OutboundGroupSession using the supplied key. // Returns error on failure. If the key doesn't match the one used to encrypt // the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the // base64 couldn't be decoded then the error will be "INVALID_BASE64". -func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, error) { - return InitNewOutboundGroupSessionFromPickled(pickled, key) +func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, EmptyInput + } + s := NewBlankOutboundGroupSession() + return s, s.Unpickle(pickled, key) } // NewOutboundGroupSession creates a new outbound group session. -func NewOutboundGroupSession() (OutboundGroupSession, error) { - return InitNewOutboundGroupSession() +func NewOutboundGroupSession() *OutboundGroupSession { + s := NewBlankOutboundGroupSession() + random := make([]byte, s.createRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(NotEnoughGoRandom) + } + r := C.olm_init_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&random[0]), + C.size_t(len(random))) + if r == errorVal() { + panic(s.lastError()) + } + return s } -// NewBlankOutboundGroupSession initialises an empty [OutboundGroupSession]. -func NewBlankOutboundGroupSession() OutboundGroupSession { - return InitNewBlankOutboundGroupSession() +// outboundGroupSessionSize is the size of an outbound group session object in +// bytes. +func outboundGroupSessionSize() uint { + return uint(C.olm_outbound_group_session_size()) +} + +// newOutboundGroupSession initialises an empty OutboundGroupSession. +func NewBlankOutboundGroupSession() *OutboundGroupSession { + memory := make([]byte, outboundGroupSessionSize()) + return &OutboundGroupSession{ + int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to an +// outbound group session. +func (s *OutboundGroupSession) lastError() error { + return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int)))) +} + +// Clear clears the memory used to back this OutboundGroupSession. +func (s *OutboundGroupSession) Clear() error { + r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int)) + if r == errorVal() { + return s.lastError() + } else { + return nil + } +} + +// pickleLen returns the number of bytes needed to store an outbound group +// session. +func (s *OutboundGroupSession) pickleLen() uint { + return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the +// OutboundGroupSession using the supplied key. +func (s *OutboundGroupSession) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + panic(s.lastError()) + } + return pickled[:r] +} + +func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return NoKeyProvided + } + r := C.olm_unpickle_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *OutboundGroupSession) GobEncode() ([]byte, error) { + pickled := s.Pickle(pickleKey) + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankOutboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { + pickled := s.Pickle(pickleKey) + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return InputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankOutboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// createRandomLen returns the number of random bytes needed to create an +// Account. +func (s *OutboundGroupSession) createRandomLen() uint { + return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// encryptMsgLen returns the size of the next message in bytes for the given +// number of plain-text bytes. +func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { + return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen))) +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { + if len(plaintext) == 0 { + panic(EmptyInput) + } + message := make([]byte, s.encryptMsgLen(len(plaintext))) + r := C.olm_group_encrypt( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&plaintext[0]), + C.size_t(len(plaintext)), + (*C.uint8_t)(&message[0]), + C.size_t(len(message))) + if r == errorVal() { + panic(s.lastError()) + } + return message[:r] +} + +// sessionIdLen returns the number of bytes needed to store a session ID. +func (s *OutboundGroupSession) sessionIdLen() uint { + return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// ID returns a base64-encoded identifier for this session. +func (s *OutboundGroupSession) ID() id.SessionID { + sessionID := make([]byte, s.sessionIdLen()) + r := C.olm_outbound_group_session_id( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&sessionID[0]), + C.size_t(len(sessionID))) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID[:r]) +} + +// MessageIndex returns the message index for this session. Each message is +// sent with an increasing index; this returns the index for the next message. +func (s *OutboundGroupSession) MessageIndex() uint { + return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int))) +} + +// sessionKeyLen returns the number of bytes needed to store a session key. +func (s *OutboundGroupSession) sessionKeyLen() uint { + return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// Key returns the base64-encoded current ratchet key for this session. +func (s *OutboundGroupSession) Key() string { + sessionKey := make([]byte, s.sessionKeyLen()) + r := C.olm_outbound_group_session_key( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) + if r == errorVal() { + panic(s.lastError()) + } + return string(sessionKey[:r]) } diff --git a/crypto/olm/outboundgroupsession_goolm.go b/crypto/olm/outboundgroupsession_goolm.go new file mode 100644 index 00000000..7c201213 --- /dev/null +++ b/crypto/olm/outboundgroupsession_goolm.go @@ -0,0 +1,111 @@ +//go:build goolm + +package olm + +import ( + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/id" +) + +// OutboundGroupSession stores an outbound encrypted messaging session for a +// group. +type OutboundGroupSession struct { + session.MegolmOutboundSession +} + +// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled +// base64 string. Decrypts the OutboundGroupSession using the supplied key. +// Returns error on failure. If the key doesn't match the one used to encrypt +// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". +func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, EmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key) + if err != nil { + return nil, err + } + return &OutboundGroupSession{ + MegolmOutboundSession: *megolmSession, + }, nil +} + +// NewOutboundGroupSession creates a new outbound group session. +func NewOutboundGroupSession() *OutboundGroupSession { + megolmSession, err := session.NewMegolmOutboundSession() + if err != nil { + panic(err) + } + return &OutboundGroupSession{ + MegolmOutboundSession: *megolmSession, + } +} + +// newOutboundGroupSession initialises an empty OutboundGroupSession. +func NewBlankOutboundGroupSession() *OutboundGroupSession { + return &OutboundGroupSession{} +} + +// Clear clears the memory used to back this OutboundGroupSession. +func (s *OutboundGroupSession) Clear() error { + s.MegolmOutboundSession = session.MegolmOutboundSession{} + return nil +} + +// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the +// OutboundGroupSession using the supplied key. +func (s *OutboundGroupSession) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled, err := s.MegolmOutboundSession.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return NoKeyProvided + } + return s.MegolmOutboundSession.Unpickle(pickled, key) +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { + if len(plaintext) == 0 { + panic(EmptyInput) + } + message, err := s.MegolmOutboundSession.Encrypt(plaintext) + if err != nil { + panic(err) + } + return message +} + +// ID returns a base64-encoded identifier for this session. +func (s *OutboundGroupSession) ID() id.SessionID { + return s.MegolmOutboundSession.SessionID() +} + +// MessageIndex returns the message index for this session. Each message is +// sent with an increasing index; this returns the index for the next message. +func (s *OutboundGroupSession) MessageIndex() uint { + return uint(s.MegolmOutboundSession.Ratchet.Counter) +} + +// Key returns the base64-encoded current ratchet key for this session. +func (s *OutboundGroupSession) Key() string { + message, err := s.MegolmOutboundSession.SessionSharingMessage() + if err != nil { + panic(err) + } + return string(message) +} diff --git a/crypto/olm/outboundgroupsession_test.go b/crypto/olm/outboundgroupsession_test.go deleted file mode 100644 index cbbc89f7..00000000 --- a/crypto/olm/outboundgroupsession_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package olm_test - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/libolm" -) - -func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) { - libolmSession, err := libolm.NewOutboundGroupSession() - require.NoError(t, err) - libolmPickled, err := libolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - goolmSession, err := session.MegolmOutboundSessionFromPickled(libolmPickled, []byte("test")) - require.NoError(t, err) - - goolmPickled, err := goolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") - - libolmSession2, err := libolm.NewOutboundGroupSession() - require.NoError(t, err) - err = libolmSession2.Unpickle(bytes.Clone(goolmPickled), []byte("test")) - require.NoError(t, err) - - assert.Equal(t, libolmSession.Key(), libolmSession2.Key()) -} - -func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) { - goolmSession, err := session.NewMegolmOutboundSession() - require.NoError(t, err) - - goolmPickled, err := goolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - libolmSession, err := libolm.NewOutboundGroupSession() - require.NoError(t, err) - err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) - require.NoError(t, err) - - libolmPickled, err := libolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - assert.Equal(t, goolmPickled, libolmPickled, "pickled versions are not the same") - - goolmSession2, err := session.MegolmOutboundSessionFromPickled(libolmPickled, []byte("test")) - require.NoError(t, err) - - assert.Equal(t, goolmSession.Key(), goolmSession2.Key()) - assert.Equal(t, goolmSession.SigningKey.PrivateKey, goolmSession2.SigningKey.PrivateKey) -} - -func TestMegolmOutboundSessionPickleLibolm(t *testing.T) { - libolmSession, err := libolm.NewOutboundGroupSession() - require.NoError(t, err) - libolmPickled, err := libolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - goolmSession, err := session.MegolmOutboundSessionFromPickled(bytes.Clone(libolmPickled), []byte("test")) - require.NoError(t, err) - goolmPickled, err := goolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") - assert.Equal(t, goolmSession.SigningKey.PrivateKey.PubKey(), goolmSession.SigningKey.PublicKey) - - // Ensure that the key export is the same and that the pickle is the same - assert.Equal(t, libolmSession.Key(), goolmSession.Key(), "keys are not the same") -} - -func TestMegolmOutboundSessionPickleGoolm(t *testing.T) { - goolmSession, err := session.NewMegolmOutboundSession() - require.NoError(t, err) - goolmPickled, err := goolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - libolmSession, err := libolm.NewOutboundGroupSession() - require.NoError(t, err) - err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) - require.NoError(t, err) - libolmPickled, err := libolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") - assert.Equal(t, goolmSession.SigningKey.PrivateKey.PubKey(), goolmSession.SigningKey.PublicKey) - - // Ensure that the key export is the same and that the pickle is the same - assert.Equal(t, libolmSession.Key(), goolmSession.Key(), "keys are not the same") -} - -func FuzzMegolmOutboundSession_Encrypt(f *testing.F) { - f.Add([]byte("anything")) - - f.Fuzz(func(t *testing.T, plaintext []byte) { - if len(plaintext) == 0 { - t.Skip("empty plaintext is not supported") - } - - libolmSession, err := libolm.NewOutboundGroupSession() - require.NoError(t, err) - libolmPickled, err := libolmSession.Pickle([]byte("test")) - require.NoError(t, err) - - goolmSession, err := session.MegolmOutboundSessionFromPickled(bytes.Clone(libolmPickled), []byte("test")) - require.NoError(t, err) - - assert.Equal(t, libolmSession.Key(), goolmSession.Key()) - - // Encrypt the plaintext ten times because the ratchet increments. - for i := 0; i < 10; i++ { - assert.EqualValues(t, i, libolmSession.MessageIndex()) - assert.EqualValues(t, i, goolmSession.MessageIndex()) - - libolmEncrypted, err := libolmSession.Encrypt(plaintext) - require.NoError(t, err) - - goolmEncrypted, err := goolmSession.Encrypt(plaintext) - require.NoError(t, err) - - assert.Equal(t, libolmEncrypted, goolmEncrypted) - - assert.EqualValues(t, i+1, libolmSession.MessageIndex()) - assert.EqualValues(t, i+1, goolmSession.MessageIndex()) - } - }) -} diff --git a/crypto/olm/pk_goolm.go b/crypto/olm/pk_goolm.go new file mode 100644 index 00000000..372c94fa --- /dev/null +++ b/crypto/olm/pk_goolm.go @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// When the goolm build flag is enabled, this file will make [PKSigning] +// constructors use the goolm constuctors. + +//go:build goolm + +package olm + +import "maunium.net/go/mautrix/crypto/goolm/pk" + +// NewPKSigningFromSeed creates a new PKSigning object using the given seed. +func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { + return pk.NewSigningFromSeed(seed) +} + +// NewPKSigning creates a new [PKSigning] object, containing a key pair for +// signing messages. +func NewPKSigning() (PKSigning, error) { + return pk.NewSigning() +} + +func NewPKDecryption(privateKey []byte) (PKDecryption, error) { + return pk.NewDecryption() +} diff --git a/crypto/olm/pk.go b/crypto/olm/pk_interface.go similarity index 52% rename from crypto/olm/pk.go rename to crypto/olm/pk_interface.go index 70ee452d..11c41431 100644 --- a/crypto/olm/pk.go +++ b/crypto/olm/pk_interface.go @@ -7,6 +7,7 @@ package olm import ( + "maunium.net/go/mautrix/crypto/goolm/pk" "maunium.net/go/mautrix/id" ) @@ -26,32 +27,15 @@ type PKSigning interface { SignJSON(obj any) (string, error) } +var _ PKSigning = (*pk.Signing)(nil) + // PKDecryption is an interface for decrypting messages. type PKDecryption interface { // PublicKey returns the public key. PublicKey() id.Curve25519 // Decrypt verifies and decrypts the given message. - Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) + Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) } -var InitNewPKSigning func() (PKSigning, error) -var InitNewPKSigningFromSeed func(seed []byte) (PKSigning, error) -var InitNewPKDecryptionFromPrivateKey func(privateKey []byte) (PKDecryption, error) - -// NewPKSigning creates a new [PKSigning] object, containing a key pair for -// signing messages. -func NewPKSigning() (PKSigning, error) { - return InitNewPKSigning() -} - -// NewPKSigningFromSeed creates a new PKSigning object using the given seed. -func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { - return InitNewPKSigningFromSeed(seed) -} - -// NewPKDecryptionFromPrivateKey creates a new [PKDecryption] from a -// base64-encoded private key. -func NewPKDecryptionFromPrivateKey(privateKey []byte) (PKDecryption, error) { - return InitNewPKDecryptionFromPrivateKey(privateKey) -} +var _ PKDecryption = (*pk.Decryption)(nil) diff --git a/crypto/libolm/pk.go b/crypto/olm/pk_libolm.go similarity index 52% rename from crypto/libolm/pk.go rename to crypto/olm/pk_libolm.go index 2683cf15..0854b4d1 100644 --- a/crypto/libolm/pk.go +++ b/crypto/olm/pk_libolm.go @@ -4,7 +4,9 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package libolm +//go:build !goolm + +package olm // #cgo LDFLAGS: -lolm -lstdc++ // #include @@ -14,26 +16,24 @@ import "C" import ( "crypto/rand" "encoding/json" - "runtime" "unsafe" "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) -// PKSigning stores a key pair for signing messages. -type PKSigning struct { +// LibOlmPKSigning stores a key pair for signing messages. +type LibOlmPKSigning struct { int *C.OlmPkSigning mem []byte publicKey id.Ed25519 seed []byte } -// Ensure that [PKSigning] implements [olm.PKSigning]. -var _ olm.PKSigning = (*PKSigning)(nil) +// Ensure that LibOlmPKSigning implements PKSigning. +var _ PKSigning = (*LibOlmPKSigning)(nil) func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) @@ -51,27 +51,22 @@ func pkSigningSignatureLength() uint { return uint(C.olm_pk_signature_length()) } -func newBlankPKSigning() *PKSigning { +func newBlankPKSigning() *LibOlmPKSigning { memory := make([]byte, pkSigningSize()) - return &PKSigning{ - int: C.olm_pk_signing(unsafe.Pointer(unsafe.SliceData(memory))), + return &LibOlmPKSigning{ + int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), mem: memory, } } // NewPKSigningFromSeed creates a new [PKSigning] object using the given seed. -func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) { +func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { p := newBlankPKSigning() p.clear() pubKey := make([]byte, pkSigningPublicKeyLength()) - r := C.olm_pk_signing_key_from_seed( - (*C.OlmPkSigning)(p.int), - unsafe.Pointer(unsafe.SliceData(pubKey)), - C.size_t(len(pubKey)), - unsafe.Pointer(unsafe.SliceData(seed)), - C.size_t(len(seed)), - ) - if r == errorVal() { + if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), + unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), + unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { return nil, p.lastError() } p.publicKey = id.Ed25519(pubKey) @@ -79,51 +74,44 @@ func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) { return p, nil } -// NewPKSigning creates a new [PKSigning] object, containing a key pair for +// NewPKSigning creates a new LibOlmPKSigning object, containing a key pair for // signing messages. -func NewPKSigning() (*PKSigning, error) { +func NewPKSigning() (PKSigning, error) { // Generate the seed seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(olm.ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err } -func (p *PKSigning) PublicKey() id.Ed25519 { +func (p *LibOlmPKSigning) PublicKey() id.Ed25519 { return p.publicKey } -func (p *PKSigning) Seed() []byte { +func (p *LibOlmPKSigning) Seed() []byte { return p.seed } -// clear clears the underlying memory of a [PKSigning] object. -func (p *PKSigning) clear() { +// clear clears the underlying memory of a LibOlmPKSigning object. +func (p *LibOlmPKSigning) clear() { C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int)) } // Sign creates a signature for the given message using this key. -func (p *PKSigning) Sign(message []byte) ([]byte, error) { +func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) - r := C.olm_pk_sign( - (*C.OlmPkSigning)(p.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), - C.size_t(len(message)), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(signature))), - C.size_t(len(signature)), - ) - runtime.KeepAlive(message) - if r == errorVal() { + if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), + (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { return nil, p.lastError() } return signature, nil } // SignJSON creates a signature for the given object after encoding it to canonical JSON. -func (p *PKSigning) SignJSON(obj interface{}) (string, error) { +func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) { objJSON, err := json.Marshal(obj) if err != nil { return "", err @@ -138,15 +126,15 @@ func (p *PKSigning) SignJSON(obj interface{}) (string, error) { } // lastError returns the last error that happened in relation to this -// [PKSigning] object. -func (p *PKSigning) lastError() error { +// LibOlmPKSigning object. +func (p *LibOlmPKSigning) lastError() error { return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int)))) } -type PKDecryption struct { +type LibOlmPKDecryption struct { int *C.OlmPkDecryption mem []byte - publicKey []byte + PublicKey []byte } func pkDecryptionSize() uint { @@ -157,56 +145,34 @@ func pkDecryptionPublicKeySize() uint { return uint(C.olm_pk_key_length()) } -func NewPkDecryption(privateKey []byte) (*PKDecryption, error) { +func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) { memory := make([]byte, pkDecryptionSize()) - p := &PKDecryption{ - int: C.olm_pk_decryption(unsafe.Pointer(unsafe.SliceData(memory))), + p := &LibOlmPKDecryption{ + int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), mem: memory, } p.clear() pubKey := make([]byte, pkDecryptionPublicKeySize()) - r := C.olm_pk_key_from_private( - (*C.OlmPkDecryption)(p.int), - unsafe.Pointer(unsafe.SliceData(pubKey)), - C.size_t(len(pubKey)), - unsafe.Pointer(unsafe.SliceData(privateKey)), - C.size_t(len(privateKey)), - ) - runtime.KeepAlive(privateKey) - if r == errorVal() { + if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), + unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { return nil, p.lastError() } - p.publicKey = pubKey + p.PublicKey = pubKey return p, nil } -func (p *PKDecryption) PublicKey() id.Curve25519 { - return id.Curve25519(p.publicKey) -} - -func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { - maxPlaintextLength := uint(C.olm_pk_max_plaintext_length( - (*C.OlmPkDecryption)(p.int), - C.size_t(len(ciphertext)), - )) +func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { + maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) plaintext := make([]byte, maxPlaintextLength) - size := C.olm_pk_decrypt( - (*C.OlmPkDecryption)(p.int), - unsafe.Pointer(unsafe.SliceData(ephemeralKey)), - C.size_t(len(ephemeralKey)), - unsafe.Pointer(unsafe.SliceData(mac)), - C.size_t(len(mac)), - unsafe.Pointer(unsafe.SliceData(ciphertext)), - C.size_t(len(ciphertext)), - unsafe.Pointer(unsafe.SliceData(plaintext)), - C.size_t(len(plaintext)), - ) - runtime.KeepAlive(ephemeralKey) - runtime.KeepAlive(mac) - runtime.KeepAlive(ciphertext) + size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)), + unsafe.Pointer(&mac[0]), C.size_t(len(mac)), + unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)), + unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) if size == errorVal() { return nil, p.lastError() } @@ -215,12 +181,12 @@ func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byt } // Clear clears the underlying memory of a PkDecryption object. -func (p *PKDecryption) clear() { +func (p *LibOlmPKDecryption) clear() { C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int)) } // lastError returns the last error that happened in relation to this -// [PKDecryption] object. -func (p *PKDecryption) lastError() error { +// LibOlmPKDecryption object. +func (p *LibOlmPKDecryption) lastError() error { return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int)))) } diff --git a/crypto/olm/pk_test.go b/crypto/olm/pk_test.go index 99ac1e6b..b57e6571 100644 --- a/crypto/olm/pk_test.go +++ b/crypto/olm/pk_test.go @@ -4,7 +4,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -// Only run this test if goolm is disabled (that is, libolm is used). +// Only run this test if goo is disabled (that is, libolm is used). +//go:build !goolm package olm_test @@ -15,7 +16,7 @@ import ( "github.com/stretchr/testify/require" "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/crypto/libolm" + "maunium.net/go/mautrix/crypto/olm" ) func FuzzSign(f *testing.F) { @@ -23,7 +24,7 @@ func FuzzSign(f *testing.F) { goolmPkSigning, err := pk.NewSigningFromSeed(seed) require.NoError(f, err) - libolmPkSigning, err := libolm.NewPKSigningFromSeed(seed) + libolmPkSigning, err := olm.NewPKSigningFromSeed(seed) require.NoError(f, err) f.Add([]byte("message")) diff --git a/crypto/olm/session.go b/crypto/olm/session.go index c4b91ffc..185e0b3d 100644 --- a/crypto/olm/session.go +++ b/crypto/olm/session.go @@ -1,83 +1,362 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !goolm package olm -import "maunium.net/go/mautrix/id" +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +// #include +// #include +// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak)); +// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) { +// if (olm_session_describe) { +// olm_session_describe(session, buf, buflen); +// } else { +// sprintf(buf, "olm_session_describe not supported"); +// } +// } +import "C" -type Session interface { - // Pickle returns a Session as a base64 string. Encrypts the Session using - // the supplied key. - Pickle(key []byte) ([]byte, error) +import ( + "crypto/rand" + "encoding/base64" + "unsafe" - // Unpickle loads a Session from a pickled base64 string. Decrypts the - // Session using the supplied key. - Unpickle(pickled, key []byte) error + "maunium.net/go/mautrix/id" +) - // ID returns an identifier for this Session. Will be the same for both - // ends of the conversation. - ID() id.SessionID - - // HasReceivedMessage returns true if this session has received any - // message. - HasReceivedMessage() bool - - // MatchesInboundSession checks if the PRE_KEY message is for this in-bound - // Session. This can happen if multiple messages are sent to this Account - // before this Account sends a message in reply. Returns true if the - // session matches. Returns false if the session does not match. Returns - // error on failure. If the base64 couldn't be decoded then the error will - // be "INVALID_BASE64". If the message was for an unsupported protocol - // version then the error will be "BAD_MESSAGE_VERSION". If the message - // couldn't be decoded then then the error will be "BAD_MESSAGE_FORMAT". - MatchesInboundSession(oneTimeKeyMsg string) (bool, error) - - // MatchesInboundSessionFrom checks if the PRE_KEY message is for this - // in-bound Session. This can happen if multiple messages are sent to this - // Account before this Account sends a message in reply. Returns true if - // the session matches. Returns false if the session does not match. - // Returns error on failure. If the base64 couldn't be decoded then the - // error will be "INVALID_BASE64". If the message was for an unsupported - // protocol version then the error will be "BAD_MESSAGE_VERSION". If the - // message couldn't be decoded then then the error will be - // "BAD_MESSAGE_FORMAT". - MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) - - // EncryptMsgType returns the type of the next message that Encrypt will - // return. Returns MsgTypePreKey if the message will be a PRE_KEY message. - // Returns MsgTypeMsg if the message will be a normal message. - EncryptMsgType() id.OlmMsgType - - // Encrypt encrypts a message using the Session. Returns the encrypted - // message as base64. - Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) - - // Decrypt decrypts a message using the Session. Returns the plain-text on - // success. Returns error on failure. If the base64 couldn't be decoded - // then the error will be "INVALID_BASE64". If the message is for an - // unsupported version of the protocol then the error will be - // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error - // will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then - // the error will be "BAD_MESSAGE_MAC". - Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) - - // Describe generates a string describing the internal state of an olm - // session for debugging and logging purposes. - Describe() string +// Session stores an end to end encrypted messaging session. +type Session struct { + int *C.OlmSession + mem []byte } -var InitSessionFromPickled func(pickled, key []byte) (Session, error) -var InitNewBlankSession func() Session +// sessionSize is the size of a session object in bytes. +func sessionSize() uint { + return uint(C.olm_session_size()) +} // SessionFromPickled loads a Session from a pickled base64 string. Decrypts -// the Session using the supplied key. Returns error on failure. -func SessionFromPickled(pickled, key []byte) (Session, error) { - return InitSessionFromPickled(pickled, key) +// the Session using the supplied key. Returns error on failure. If the key +// doesn't match the one used to encrypt the Session then the error will be +// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". +func SessionFromPickled(pickled, key []byte) (*Session, error) { + if len(pickled) == 0 { + return nil, EmptyInput + } + s := NewBlankSession() + return s, s.Unpickle(pickled, key) } -func NewBlankSession() Session { - return InitNewBlankSession() +func NewBlankSession() *Session { + memory := make([]byte, sessionSize()) + return &Session{ + int: C.olm_session(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to a +// session. +func (s *Session) lastError() error { + return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int)))) +} + +// Clear clears the memory used to back this Session. +func (s *Session) Clear() error { + r := C.olm_clear_session((*C.OlmSession)(s.int)) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// pickleLen returns the number of bytes needed to store a session. +func (s *Session) pickleLen() uint { + return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int))) +} + +// createOutboundRandomLen returns the number of random bytes needed to create +// an outbound session. +func (s *Session) createOutboundRandomLen() uint { + return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))) +} + +// idLen returns the length of the buffer needed to return the id for this +// session. +func (s *Session) idLen() uint { + return uint(C.olm_session_id_length((*C.OlmSession)(s.int))) +} + +// encryptRandomLen returns the number of random bytes needed to encrypt the +// next message. +func (s *Session) encryptRandomLen() uint { + return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int))) +} + +// encryptMsgLen returns the size of the next message in bytes for the given +// number of plain-text bytes. +func (s *Session) encryptMsgLen(plainTextLen int) uint { + return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))) +} + +// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a +// given message could decode to. The actual size could be different due to +// padding. Returns error on failure. If the message base64 couldn't be +// decoded then the error will be "INVALID_BASE64". If the message is for an +// unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be "BAD_MESSAGE_FORMAT". +func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { + if len(message) == 0 { + return 0, EmptyInput + } + r := C.olm_decrypt_max_plaintext_length( + (*C.OlmSession)(s.int), + C.size_t(msgType), + unsafe.Pointer(C.CString(message)), + C.size_t(len(message))) + if r == errorVal() { + return 0, s.lastError() + } + return uint(r), nil +} + +// Pickle returns a Session as a base64 string. Encrypts the Session using the +// supplied key. +func (s *Session) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + panic(s.lastError()) + } + return pickled[:r] +} + +func (s *Session) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return NoKeyProvided + } + r := C.olm_unpickle_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *Session) GobEncode() ([]byte, error) { + pickled := s.Pickle(pickleKey) + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *Session) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *Session) MarshalJSON() ([]byte, error) { + pickled := s.Pickle(pickleKey) + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *Session) UnmarshalJSON(data []byte) error { + if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return InputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// Id returns an identifier for this Session. Will be the same for both ends +// of the conversation. +func (s *Session) ID() id.SessionID { + sessionID := make([]byte, s.idLen()) + r := C.olm_session_id( + (*C.OlmSession)(s.int), + unsafe.Pointer(&sessionID[0]), + C.size_t(len(sessionID))) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID) +} + +// HasReceivedMessage returns true if this session has received any message. +func (s *Session) HasReceivedMessage() bool { + switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) { + case 0: + return false + default: + return true + } +} + +// MatchesInboundSession checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". If the message was for an unsupported protocol version +// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be +// decoded then then the error will be "BAD_MESSAGE_FORMAT". +func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { + if len(oneTimeKeyMsg) == 0 { + return false, EmptyInput + } + r := C.olm_matches_inbound_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), + C.size_t(len(oneTimeKeyMsg))) + if r == 1 { + return true, nil + } else if r == 0 { + return false, nil + } else { // if r == errorVal() + return false, s.lastError() + } +} + +// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". If the message was for an unsupported protocol version +// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be +// decoded then then the error will be "BAD_MESSAGE_FORMAT". +func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { + if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { + return false, EmptyInput + } + r := C.olm_matches_inbound_session_from( + (*C.OlmSession)(s.int), + unsafe.Pointer(&([]byte(theirIdentityKey))[0]), + C.size_t(len(theirIdentityKey)), + unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), + C.size_t(len(oneTimeKeyMsg))) + if r == 1 { + return true, nil + } else if r == 0 { + return false, nil + } else { // if r == errorVal() + return false, s.lastError() + } +} + +// EncryptMsgType returns the type of the next message that Encrypt will +// return. Returns MsgTypePreKey if the message will be a PRE_KEY message. +// Returns MsgTypeMsg if the message will be a normal message. Returns error +// on failure. +func (s *Session) EncryptMsgType() id.OlmMsgType { + switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) { + case C.size_t(id.OlmMsgTypePreKey): + return id.OlmMsgTypePreKey + case C.size_t(id.OlmMsgTypeMsg): + return id.OlmMsgTypeMsg + default: + panic("olm_encrypt_message_type returned invalid result") + } +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { + if len(plaintext) == 0 { + panic(EmptyInput) + } + // Make the slice be at least length 1 + random := make([]byte, s.encryptRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(NotEnoughGoRandom) + } + messageType := s.EncryptMsgType() + message := make([]byte, s.encryptMsgLen(len(plaintext))) + r := C.olm_encrypt( + (*C.OlmSession)(s.int), + unsafe.Pointer(&plaintext[0]), + C.size_t(len(plaintext)), + unsafe.Pointer(&random[0]), + C.size_t(len(random)), + unsafe.Pointer(&message[0]), + C.size_t(len(message))) + if r == errorVal() { + panic(s.lastError()) + } + return messageType, message[:r] +} + +// Decrypt decrypts a message using the Session. Returns the the plain-text on +// success. Returns error on failure. If the base64 couldn't be decoded then +// the error will be "INVALID_BASE64". If the message is for an unsupported +// version of the protocol then the error will be "BAD_MESSAGE_VERSION". If +// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT". +// If the MAC on the message was invalid then the error will be +// "BAD_MESSAGE_MAC". +func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { + if len(message) == 0 { + return nil, EmptyInput + } + decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) + if err != nil { + return nil, err + } + messageCopy := []byte(message) + plaintext := make([]byte, decryptMaxPlaintextLen) + r := C.olm_decrypt( + (*C.OlmSession)(s.int), + C.size_t(msgType), + unsafe.Pointer(&(messageCopy)[0]), + C.size_t(len(messageCopy)), + unsafe.Pointer(&plaintext[0]), + C.size_t(len(plaintext))) + if r == errorVal() { + return nil, s.lastError() + } + return plaintext[:r], nil +} + +// https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393 +const maxDescribeSize = 600 + +// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. +func (s *Session) Describe() string { + desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize))) + defer C.free(unsafe.Pointer(desc)) + C.meowlm_session_describe( + (*C.OlmSession)(s.int), + desc, + C.size_t(maxDescribeSize)) + return C.GoString(desc) } diff --git a/crypto/olm/session_goolm.go b/crypto/olm/session_goolm.go new file mode 100644 index 00000000..c77efaa2 --- /dev/null +++ b/crypto/olm/session_goolm.go @@ -0,0 +1,110 @@ +//go:build goolm + +package olm + +import ( + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/id" +) + +// Session stores an end to end encrypted messaging session. +type Session struct { + session.OlmSession +} + +// SessionFromPickled loads a Session from a pickled base64 string. Decrypts +// the Session using the supplied key. Returns error on failure. +func SessionFromPickled(pickled, key []byte) (*Session, error) { + if len(pickled) == 0 { + return nil, EmptyInput + } + s := NewBlankSession() + return s, s.Unpickle(pickled, key) +} + +func NewBlankSession() *Session { + return &Session{} +} + +// Clear clears the memory used to back this Session. +func (s *Session) Clear() error { + s.OlmSession = session.OlmSession{} + return nil +} + +// Pickle returns a Session as a base64 string. Encrypts the Session using the +// supplied key. +func (s *Session) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(NoKeyProvided) + } + pickled, err := s.OlmSession.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +func (s *Session) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return NoKeyProvided + } else if len(pickled) == 0 { + return EmptyInput + } + sOlm, err := session.OlmSessionFromPickled(pickled, key) + if err != nil { + return err + } + s.OlmSession = *sOlm + return nil +} + +// MatchesInboundSession checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { + return s.MatchesInboundSessionFrom("", oneTimeKeyMsg) +} + +// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { + if theirIdentityKey != "" { + theirKey := id.Curve25519(theirIdentityKey) + return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg)) + } + return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg)) + +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { + if len(plaintext) == 0 { + panic(EmptyInput) + } + messageType, message, err := s.OlmSession.Encrypt(plaintext, nil) + if err != nil { + panic(err) + } + return messageType, message +} + +// Decrypt decrypts a message using the Session. Returns the the plain-text on +// success. Returns error on failure. +func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { + if len(message) == 0 { + return nil, EmptyInput + } + return s.OlmSession.Decrypt([]byte(message), msgType) +} + +// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. +func (s *Session) Describe() string { + return s.OlmSession.Describe() +} diff --git a/crypto/olm/session_test.go b/crypto/olm/session_test.go deleted file mode 100644 index b0b9896f..00000000 --- a/crypto/olm/session_test.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package olm_test - -import ( - "bytes" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.mau.fi/util/exerrors" - "golang.org/x/exp/maps" - - "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/libolm" - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/id" -) - -func TestBlankSession(t *testing.T) { - libolmSession := libolm.NewBlankSession() - session := session.NewOlmSession() - - assert.Equal(t, libolmSession.ID(), session.ID()) - assert.Equal(t, libolmSession.HasReceivedMessage(), session.HasReceivedMessage()) - assert.Equal(t, libolmSession.EncryptMsgType(), session.EncryptMsgType()) - assert.Equal(t, libolmSession.Describe(), session.Describe()) - - libolmPickled, err := libolmSession.Pickle([]byte("test")) - assert.NoError(t, err) - goolmPickled, err := session.Pickle([]byte("test")) - assert.NoError(t, err) - assert.Equal(t, goolmPickled, libolmPickled) -} - -func TestSessionPickle(t *testing.T) { - pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") - pickleKey := []byte("secret_key") - - goolmSession, err := session.OlmSessionFromPickled(bytes.Clone(pickledDataFromLibOlm), pickleKey) - assert.NoError(t, err) - - libolmSession, err := libolm.SessionFromPickled(bytes.Clone(pickledDataFromLibOlm), pickleKey) - assert.NoError(t, err) - - goolmPickled, err := goolmSession.Pickle(pickleKey) - require.NoError(t, err) - assert.Equal(t, pickledDataFromLibOlm, goolmPickled) - - libolmPickled, err := libolmSession.Pickle(pickleKey) - require.NoError(t, err) - assert.Equal(t, pickledDataFromLibOlm, libolmPickled) -} - -func TestSession_EncryptDecrypt(t *testing.T) { - combos := [][2]olm.Account{ - {exerrors.Must(libolm.NewAccount()), exerrors.Must(libolm.NewAccount())}, - {exerrors.Must(account.NewAccount()), exerrors.Must(account.NewAccount())}, - {exerrors.Must(libolm.NewAccount()), exerrors.Must(account.NewAccount())}, - {exerrors.Must(account.NewAccount()), exerrors.Must(libolm.NewAccount())}, - } - - for _, combo := range combos { - receiver, sender := combo[0], combo[1] - require.NoError(t, receiver.GenOneTimeKeys(50)) - require.NoError(t, sender.GenOneTimeKeys(50)) - - _, receiverCurve25519, err := receiver.IdentityKeys() - require.NoError(t, err) - accountAOTKs, err := receiver.OneTimeKeys() - require.NoError(t, err) - - senderSession, err := sender.NewOutboundSession(receiverCurve25519, accountAOTKs[maps.Keys(accountAOTKs)[0]]) - require.NoError(t, err) - - // Send a couple pre-key messages from sender -> receiver. - var receiverSession olm.Session - for i := 0; i < 10; i++ { - msgType, ciphertext, err := senderSession.Encrypt([]byte(fmt.Sprintf("prekey %d", i))) - require.NoError(t, err) - assert.Equal(t, id.OlmMsgTypePreKey, msgType) - - receiverSession, err = receiver.NewInboundSession(string(ciphertext)) - require.NoError(t, err) - - decrypted, err := receiverSession.Decrypt(string(ciphertext), msgType) - require.NoError(t, err) - assert.Equal(t, []byte(fmt.Sprintf("prekey %d", i)), decrypted) - } - - // Send some messages from receiver -> sender. - for i := 0; i < 10; i++ { - msgType, ciphertext, err := receiverSession.Encrypt([]byte(fmt.Sprintf("response %d", i))) - require.NoError(t, err) - assert.Equal(t, id.OlmMsgTypeMsg, msgType) - - decrypted, err := senderSession.Decrypt(string(ciphertext), msgType) - require.NoError(t, err) - assert.Equal(t, []byte(fmt.Sprintf("response %d", i)), decrypted) - } - - // Send some more messages from sender -> receiver - for i := 0; i < 10; i++ { - msgType, ciphertext, err := senderSession.Encrypt([]byte(fmt.Sprintf("%d", i))) - require.NoError(t, err) - assert.Equal(t, id.OlmMsgTypeMsg, msgType) - - decrypted, err := receiverSession.Decrypt(string(ciphertext), msgType) - require.NoError(t, err) - assert.Equal(t, []byte(fmt.Sprintf("%d", i)), decrypted) - } - } -} diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go deleted file mode 100644 index 6b5b65fd..00000000 --- a/crypto/registergoolm.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build goolm - -package crypto - -import ( - "maunium.net/go/mautrix/crypto/goolm" -) - -func init() { - goolm.Register() -} diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go deleted file mode 100644 index ef78b6b5..00000000 --- a/crypto/registerlibolm.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !goolm - -package crypto - -import "maunium.net/go/mautrix/crypto/libolm" - -func init() { - libolm.Register() -} diff --git a/crypto/sessions.go b/crypto/sessions.go index ccc7b784..6075a644 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -8,7 +8,6 @@ package crypto import ( "errors" - "fmt" "time" "maunium.net/go/mautrix/crypto/olm" @@ -18,14 +17,8 @@ import ( ) var ( - ErrSessionNotShared = errors.New("session has not been shared") - ErrSessionExpired = errors.New("session has expired") -) - -// Deprecated: use variables prefixed with Err -var ( - SessionNotShared = ErrSessionNotShared - SessionExpired = ErrSessionExpired + SessionNotShared = errors.New("session has not been shared") + SessionExpired = errors.New("session has expired") ) // OlmSessionList is a list of OlmSessions. @@ -61,9 +54,9 @@ func (session *OlmSession) Describe() string { return session.Internal.Describe() } -func wrapSession(session olm.Session) *OlmSession { +func wrapSession(session *olm.Session) *OlmSession { return &OlmSession{ - Internal: session, + Internal: *session, ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), @@ -75,7 +68,7 @@ func wrapSession(session olm.Session) *OlmSession { } func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) { - session, err := account.Internal.NewInboundSessionFrom(&senderKey, ciphertext) + session, err := account.Internal.NewInboundSessionFrom(senderKey, ciphertext) if err != nil { return nil, err } @@ -83,7 +76,7 @@ func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, cipher return wrapSession(session), nil } -func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { +func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { session.LastEncryptedTime = time.Now() return session.Internal.Encrypt(plaintext) } @@ -117,7 +110,6 @@ type InboundGroupSession struct { MaxMessages int IsScheduled bool KeyBackupVersion id.KeyBackupVersion - KeySource id.KeySource id id.SessionID } @@ -128,16 +120,15 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI return nil, err } return &InboundGroupSession{ - Internal: igs, + Internal: *igs, SigningKey: signingKey, SenderKey: senderKey, RoomID: roomID, - ForwardingChains: []string{}, + ForwardingChains: nil, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: isScheduled, - KeySource: id.KeySourceDirect, }, nil } @@ -157,26 +148,10 @@ func (igs *InboundGroupSession) RatchetTo(index uint32) error { if err != nil { return err } - igs.Internal = imported + igs.Internal = *imported return nil } -func (igs *InboundGroupSession) export() (*ExportedSession, error) { - key, err := igs.Internal.Export(igs.Internal.FirstKnownIndex()) - if err != nil { - return nil, fmt.Errorf("failed to export session: %w", err) - } - return &ExportedSession{ - Algorithm: id.AlgorithmMegolmV1, - ForwardingChains: igs.ForwardingChains, - RoomID: igs.RoomID, - SenderKey: igs.SenderKey, - SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey}, - SessionID: igs.ID(), - SessionKey: string(key), - }, nil -} - type OGSState int const ( @@ -205,13 +180,9 @@ type OutboundGroupSession struct { content *event.RoomKeyEventContent } -func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) (*OutboundGroupSession, error) { - internal, err := olm.NewOutboundGroupSession() - if err != nil { - return nil, err - } +func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession { ogs := &OutboundGroupSession{ - Internal: internal, + Internal: *olm.NewOutboundGroupSession(), ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), @@ -235,7 +206,7 @@ func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.Encrypti ogs.MaxMessages = min(max(encryptionContent.RotationPeriodMessages, 1), 10000) } } - return ogs, nil + return ogs } func (ogs *OutboundGroupSession) ShareContent() event.Content { @@ -263,13 +234,13 @@ func (ogs *OutboundGroupSession) Expired() bool { func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { - return nil, ErrSessionNotShared + return nil, SessionNotShared } else if ogs.Expired() { - return nil, ErrSessionExpired + return nil, SessionExpired } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() - return ogs.Internal.Encrypt(plaintext) + return ogs.Internal.Encrypt(plaintext), nil } type TimeMixin struct { diff --git a/crypto/sharing.go b/crypto/sharing.go index 10e37ccc..c0f3e209 100644 --- a/crypto/sharing.go +++ b/crypto/sharing.go @@ -173,19 +173,6 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven return } - // https://spec.matrix.org/v1.10/client-server-api/#msecretsend - // "The recipient must ensure... that the device is a verified device owned by the recipient" - if senderDevice, err := mach.GetOrFetchDevice(ctx, evt.Sender, evt.SenderDevice); err != nil { - log.Err(err).Msg("Failed to get or fetch sender device, rejecting secret") - return - } else if senderDevice == nil { - log.Warn().Msg("Unknown sender device, rejecting secret") - return - } else if !mach.IsDeviceTrusted(ctx, senderDevice) { - log.Warn().Msg("Sender device is not verified, rejecting secret") - return - } - mach.secretLock.Lock() secretChan := mach.secretListeners[content.RequestID] mach.secretLock.Unlock() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 138cc557..a8ccab26 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -13,7 +13,6 @@ import ( "encoding/json" "errors" "fmt" - "slices" "strings" "sync" "time" @@ -22,7 +21,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/event" @@ -53,18 +52,14 @@ var _ Store = (*SQLCryptoStore)(nil) // NewSQLCryptoStore initializes a new crypto Store using the given database, for a device's crypto material. // The stored material will be encrypted with the given key. func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID string, deviceID id.DeviceID, pickleKey []byte) *SQLCryptoStore { - store := &SQLCryptoStore{ + return &SQLCryptoStore{ DB: db.Child(sql_store_upgrade.VersionTableName, sql_store_upgrade.Table, log), PickleKey: pickleKey, AccountID: accountID, DeviceID: deviceID, - } - store.InitFields() - return store -} -func (store *SQLCryptoStore) InitFields() { - store.olmSessionCache = make(map[id.SenderKey]map[id.SessionID]*OlmSession) + olmSessionCache: make(map[id.SenderKey]map[id.SessionID]*OlmSession), + } } // Flush does nothing for this implementation as data is already persisted in the database. @@ -128,11 +123,8 @@ func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.Devi // PutAccount stores an OlmAccount in the database. func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error { store.Account = account - bytes, err := account.Internal.Pickle(store.PickleKey) - if err != nil { - return err - } - _, err = store.DB.Exec(ctx, ` + bytes := account.Internal.Pickle(store.PickleKey) + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account, account_id=excluded.account_id, @@ -145,7 +137,7 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID) - acc := &OlmAccount{Internal: olm.NewBlankAccount()} + acc := &OlmAccount{Internal: *olm.NewBlankAccount()} var accountBytes []byte err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion) if err == sql.ErrNoRows { @@ -191,7 +183,7 @@ func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) defer store.olmSessionCacheLock.Unlock() cache := store.getOlmSessionCache(key) for rows.Next() { - sess := OlmSession{Internal: olm.NewBlankSession()} + sess := OlmSession{Internal: *olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID err = rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime) @@ -220,7 +212,7 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session return data } -// GetLatestSession retrieves the Olm session for a given sender key from the database that had the most recent successful decryption. +// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID. func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() @@ -228,7 +220,7 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", key, store.AccountID) - sess := OlmSession{Internal: olm.NewBlankSession()} + sess := OlmSession{Internal: *olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID @@ -250,26 +242,12 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender } } -// GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key. -// This will exclude sessions that have never been used to encrypt or decrypt a message. -func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) { - err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (last_encrypted <> created_at OR last_decrypted <> created_at) ORDER BY created_at DESC LIMIT 1", - key, store.AccountID).Scan(&createdAt) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return -} - // AddSession persists an Olm session for a sender in the database. func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() - sessionBytes, err := session.Internal.Pickle(store.PickleKey) - if err != nil { - return err - } - _, err = store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", + sessionBytes := session.Internal.Pickle(store.PickleKey) + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID) store.getOlmSessionCache(key)[session.ID()] = session return err @@ -277,43 +255,12 @@ func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, s // UpdateSession replaces the Olm session for a sender in the database. func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { - sessionBytes, err := session.Internal.Pickle(store.PickleKey) - if err != nil { - return err - } - _, err = store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", + sessionBytes := session.Internal.Pickle(store.PickleKey) + _, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID) return err } -func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { - _, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_session WHERE session_id=$1 AND account_id=$2", session.ID(), store.AccountID) - return err -} - -func (store *SQLCryptoStore) PutOlmHash(ctx context.Context, messageHash [32]byte, receivedAt time.Time) error { - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.AccountID, receivedAt.UnixMilli(), messageHash[:]) - return err -} - -func (store *SQLCryptoStore) GetOlmHash(ctx context.Context, messageHash [32]byte) (receivedAt time.Time, err error) { - var receivedAtInt int64 - err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash[:]).Scan(&receivedAtInt) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return - } - receivedAt = time.UnixMilli(receivedAtInt) - return -} - -func (store *SQLCryptoStore) DeleteOldOlmHashes(ctx context.Context, beforeTS time.Time) error { - _, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_message_hash WHERE account_id = $1 AND received_at < $2", store.AccountID, beforeTS.UnixMilli()) - return err -} - func datePtr(t time.Time) *time.Time { if t.IsZero() { return nil @@ -323,13 +270,7 @@ func datePtr(t time.Time) *time.Time { // PutGroupSession stores an inbound Megolm group session for a room, sender and session. func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *InboundGroupSession) error { - sessionBytes, err := session.Internal.Pickle(store.PickleKey) - if err != nil { - return err - } - if session.ForwardingChains == nil { - session.ForwardingChains = []string{} - } + sessionBytes := session.Internal.Pickle(store.PickleKey) forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) if err != nil { @@ -346,23 +287,22 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou Int("max_messages", session.MaxMessages). Bool("is_scheduled", session.IsScheduled). Stringer("key_backup_version", session.KeyBackupVersion). - Stringer("key_source", session.KeySource). Msg("Upserting megolm inbound group session") _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, - key_backup_version=excluded.key_backup_version, key_source=excluded.key_source + key_backup_version=excluded.key_backup_version `, session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages), - session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID, + session.IsScheduled, session.KeyBackupVersion, store.AccountID, ) return err } @@ -375,13 +315,12 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - var keySource id.KeySource err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, roomID, sessionID, store.AccountID, - ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) + ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -401,7 +340,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room return nil, err } return &InboundGroupSession{ - Internal: igs, + Internal: *igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, @@ -412,7 +351,6 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, - KeySource: keySource, }, nil } @@ -436,7 +374,10 @@ func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id. AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID) - return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() + if err != nil { + return nil, err + } + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -464,7 +405,10 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([] return nil, fmt.Errorf("unsupported dialect") } res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) - return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() + if err != nil { + return nil, err + } + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -474,7 +418,10 @@ func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([ WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID) - return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() + if err != nil { + return nil, err + } + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error { @@ -508,7 +455,7 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID }, nil } -func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { +func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { igs = olm.NewBlankInboundGroupSession() err = igs.Unpickle(sessionBytes, store.PickleKey) if err != nil { @@ -516,8 +463,6 @@ func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSa } if forwardingChains != "" { chains = strings.Split(forwardingChains, ",") - } else { - chains = []string{} } var rs RatchetSafety if len(ratchetSafetyBytes) > 0 { @@ -537,8 +482,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - var keySource id.KeySource - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if err != nil { return nil, err } @@ -547,7 +491,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In return nil, err } return &InboundGroupSession{ - Internal: igs, + Internal: *igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, @@ -558,13 +502,12 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, - KeySource: keySource, }, nil } func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -573,7 +516,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) @@ -582,7 +525,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, store.AccountID, version, ) @@ -591,11 +534,8 @@ func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { - sessionBytes, err := session.Internal.Pickle(store.PickleKey) - if err != nil { - return err - } - _, err = store.DB.Exec(ctx, ` + sessionBytes := session.Internal.Pickle(store.PickleKey) + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -610,11 +550,8 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, sessio // UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID. func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { - sessionBytes, err := session.Internal.Pickle(store.PickleKey) - if err != nil { - return err - } - _, err = store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", + sessionBytes := session.Internal.Pickle(store.PickleKey) + _, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID) return err } @@ -639,7 +576,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID if err != nil { return nil, err } - ogs.Internal = intOGS + ogs.Internal = *intOGS ogs.RoomID = roomID ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond return &ogs, nil @@ -669,20 +606,6 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. If the index hasn't been stored, this will store it. func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { - if eventID == "" && timestamp == 0 { - var notOK bool - const validateEmptyQuery = ` - SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3) - ` - err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK) - if notOK { - zerolog.Ctx(ctx).Debug(). - Uint("message_index", index). - Msg("Rejecting event without event ID and timestamp due to already knowing them") - } - return !notOK, err - } - const validateQuery = ` INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5) @@ -729,8 +652,11 @@ func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) ( } rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) + if err != nil { + return nil, err + } data := make(map[id.DeviceID]*id.Device) - err = dbutil.NewRowIterWithError(rows, scanDevice, err).Iter(func(device *id.Device) (bool, error) { + err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) { data[device.DeviceID] = device return true, nil }) @@ -850,18 +776,19 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. placeholders, params := userIDsToParams(users) rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+placeholders+")", params...) } - return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() + if err != nil { + return users, err + } + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() } // MarkTrackedUsersOutdated flags that the device list for given users are outdated. func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) (err error) { - for chunk := range slices.Chunk(users, 1000) { - if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { - _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(chunk)) - } else { - placeholders, params := userIDsToParams(chunk) - _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) - } + if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) + } else { + placeholders, params := userIDsToParams(users) + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) } return } @@ -869,7 +796,10 @@ func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) { rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE") - return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() + if err != nil { + return nil, err + } + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() } // PutCrossSigningKey stores a cross-signing key of some user along with its usage. @@ -954,7 +884,7 @@ func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id. } func (store *SQLCryptoStore) PutSecret(ctx context.Context, name id.Secret, value string) error { - bytes, err := libolmpickle.Pickle(store.PickleKey, []byte(value)) + bytes, err := cipher.Pickle(store.PickleKey, []byte(value)) if err != nil { return err } @@ -973,7 +903,7 @@ func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (val } else if err != nil { return "", err } - bytes, err = libolmpickle.Unpickle(store.PickleKey, bytes) + bytes, err = cipher.Unpickle(store.PickleKey, bytes) return string(bytes), err } diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 3709f1e5..7e039af5 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v19 (compatible with v15+): Latest revision +-- v0 -> v15: Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -43,17 +43,6 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session ( last_encrypted timestamp NOT NULL, PRIMARY KEY (account_id, session_id) ); -CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); - -CREATE TABLE crypto_olm_message_hash ( - account_id TEXT NOT NULL, - received_at BIGINT NOT NULL, - message_hash bytea NOT NULL PRIMARY KEY, - - CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) - REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE -); -CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( account_id TEXT, @@ -71,11 +60,8 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, key_backup_version TEXT NOT NULL DEFAULT '', - key_source TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); --- Useful index to find keys that need backing up -CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, diff --git a/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql b/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql deleted file mode 100644 index f0c3a0c5..00000000 --- a/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v16 (compatible with v15+): Add index to crypto_olm_sessions to speedup lookups by sender_key -CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); diff --git a/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql b/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql deleted file mode 100644 index 525bbb52..00000000 --- a/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql +++ /dev/null @@ -1,11 +0,0 @@ --- v17 (compatible with v15+): Add table for decrypted Olm message hashes -CREATE TABLE crypto_olm_message_hash ( - account_id TEXT NOT NULL, - received_at BIGINT NOT NULL, - message_hash bytea NOT NULL PRIMARY KEY, - - CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) - REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE -); - -CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); diff --git a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql deleted file mode 100644 index da26da0f..00000000 --- a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v18 (compatible with v15+): Add an index to the megolm_inbound_session table to make finding sessions to backup faster -CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql deleted file mode 100644 index f624222f..00000000 --- a/crypto/sql_store_upgrade/19-megolm-session-source.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v19 (compatible with v15+): Store megolm session source -ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT ''; diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index 8691d032..e30925d9 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -95,22 +95,6 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } -// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it, -// alongside the unencrypted metadata, on the server. -func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error { - if len(keys) == 0 { - return ErrNoKeyGiven - } - encrypted := make(map[string]EncryptedKeyData, len(keys)) - for _, key := range keys { - encrypted[key.ID] = key.Encrypt(eventType.Type, data) - } - return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{ - Encrypted: encrypted, - Metadata: metadata, - }) -} - // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index 78ebd8f3..c973c1fe 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -7,8 +7,6 @@ package ssss import ( - "crypto/hmac" - "crypto/sha256" "encoding/base64" "fmt" "strings" @@ -59,12 +57,7 @@ func NewKey(passphrase string) (*Key, error) { // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. ivBytes := random.Bytes(utils.AESCTRIVLength) keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) - macBytes, err := keyData.calculateHash(ssssKey) - if err != nil { - // This should never happen because we just generated the IV and key. - return nil, fmt.Errorf("failed to calculate hash: %w", err) - } - keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes) + keyData.MAC = keyData.calculateHash(ssssKey) return &Key{ Key: ssssKey, @@ -110,18 +103,12 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) return nil, err } - mac, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.MAC, "=")) - if err != nil { - return nil, err - } - // derive the AES and HMAC keys for the requested event type using the SSSS key aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType) // compare the stored MAC with the one we calculated from the ciphertext - h := hmac.New(sha256.New, hmacKey[:]) - h.Write(payload) - if !hmac.Equal(h.Sum(nil), mac) { + calcMac := utils.HMACSHA256B64(payload, hmacKey) + if strings.TrimRight(data.MAC, "=") != calcMac { return nil, ErrKeyDataMACMismatch } diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 34775fa7..210bcdcf 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -7,10 +7,7 @@ package ssss import ( - "crypto/hmac" - "crypto/sha256" "encoding/base64" - "errors" "fmt" "strings" @@ -36,10 +33,8 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } - err = kd.verifyKey(ssssKey) - if err != nil && !errors.Is(err, ErrUnverifiableKey) { - return nil, err + } else if !kd.VerifyKey(ssssKey) { + return nil, ErrIncorrectSSSSKey } return &Key{ @@ -54,70 +49,33 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } - err := kd.verifyKey(ssssKey) - if err != nil && !errors.Is(err, ErrUnverifiableKey) { - return nil, err + } else if !kd.VerifyKey(ssssKey) { + return nil, ErrIncorrectSSSSKey } return &Key{ ID: keyID, Key: ssssKey, Metadata: kd, - }, err -} - -func (kd *KeyMetadata) verifyKey(key []byte) error { - if kd.MAC == "" || kd.IV == "" { - return ErrUnverifiableKey - } - unpaddedMAC := strings.TrimRight(kd.MAC, "=") - expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) - if len(unpaddedMAC) != expectedMACLength { - return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) - } - expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC) - if err != nil { - return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err) - } - calculatedMAC, err := kd.calculateHash(key) - if err != nil { - return err - } - // This doesn't really need to be constant time since it's fully local, but might as well be. - if !hmac.Equal(expectedMAC, calculatedMAC) { - return ErrIncorrectSSSSKey - } - return nil + }, nil } // VerifyKey verifies the SSSS key is valid by calculating and comparing its MAC. func (kd *KeyMetadata) VerifyKey(key []byte) bool { - return kd.verifyKey(key) == nil + return strings.TrimRight(kd.MAC, "=") == kd.calculateHash(key) } // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) { +func (kd *KeyMetadata) calculateHash(key []byte) string { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") - unpaddedIV := strings.TrimRight(kd.IV, "=") - expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) - if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 { - return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) - } - rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV) - if err != nil { - return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) - } - // TODO log a warning for non-16 byte IVs? - // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used. - ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength]) - zeroes := make([]byte, utils.AESCTRKeyLength) - encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes) - h := hmac.New(sha256.New, hmacKey[:]) - h.Write(encryptedZeroes) - return h.Sum(nil), nil + var ivBytes [utils.AESCTRIVLength]byte + _, _ = base64.RawStdEncoding.Decode(ivBytes[:], []byte(strings.TrimRight(kd.IV, "="))) + + cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) + + return utils.HMACSHA256B64(cipher, hmacKey) } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index d59809c7..96c97282 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -8,10 +8,10 @@ package ssss_test import ( "encoding/json" + "errors" "testing" "github.com/stretchr/testify/assert" - "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/crypto/ssss" ) @@ -41,42 +41,12 @@ const key2Meta = ` } ` -const key2MetaUnverified = ` -{ - "algorithm": "m.secret_storage.v1.aes-hmac-sha2" -} -` - -const key2MetaLongIV = ` -{ - "algorithm": "m.secret_storage.v1.aes-hmac-sha2", - "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=", - "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" -} -` - -const key2MetaBrokenIV = ` -{ - "algorithm": "m.secret_storage.v1.aes-hmac-sha2", - "iv": "MeowMeowMeow", - "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" -} -` - -const key2MetaBrokenMAC = ` -{ - "algorithm": "m.secret_storage.v1.aes-hmac-sha2", - "iv": "O0BOvTqiIAYjC+RMcyHfWw==", - "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" -} -` - const key2ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" const key2RecoveryKey = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" -func getKeyMeta(meta string) *ssss.KeyMetadata { +func getKey1Meta() *ssss.KeyMetadata { var km ssss.KeyMetadata - err := json.Unmarshal([]byte(meta), &km) + err := json.Unmarshal([]byte(key1Meta), &km) if err != nil { panic(err) } @@ -84,15 +54,36 @@ func getKeyMeta(meta string) *ssss.KeyMetadata { } func getKey1() *ssss.Key { - return exerrors.Must(getKeyMeta(key1Meta).VerifyRecoveryKey(key1ID, key1RecoveryKey)) + km := getKey1Meta() + key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) + if err != nil { + panic(err) + } + key.ID = key1ID + return key +} + +func getKey2Meta() *ssss.KeyMetadata { + var km ssss.KeyMetadata + err := json.Unmarshal([]byte(key2Meta), &km) + if err != nil { + panic(err) + } + return &km } func getKey2() *ssss.Key { - return exerrors.Must(getKeyMeta(key2Meta).VerifyRecoveryKey(key2ID, key2RecoveryKey)) + km := getKey2Meta() + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + if err != nil { + panic(err) + } + key.ID = key2ID + return key } func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { - km := getKeyMeta(key1Meta) + km := getKey1Meta() key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) @@ -100,45 +91,29 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { } func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { - km := getKeyMeta(key2Meta) + km := getKey2Meta() key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) } -func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) { - km := getKeyMeta(key2MetaLongIV) - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.NoError(t, err) - assert.NotNil(t, key) - assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) -} - -func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) { - km := getKeyMeta(key2MetaUnverified) - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrUnverifiableKey) - assert.NotNil(t, key) - assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) -} - func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { - km := getKeyMeta(key1Meta) + km := getKey1Meta() key, err := km.VerifyRecoveryKey(key1ID, "foo") - assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey) + assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { - km := getKeyMeta(key1Meta) + km := getKey1Meta() key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) + assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { - km := getKeyMeta(key1Meta) + km := getKey1Meta() key, err := km.VerifyPassphrase(key1ID, key1Passphrase) assert.NoError(t, err) assert.NotNil(t, key) @@ -146,29 +121,15 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { } func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { - km := getKeyMeta(key1Meta) + km := getKey1Meta() key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") - assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) + assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { - km := getKeyMeta(key2Meta) + km := getKey2Meta() key, err := km.VerifyPassphrase(key2ID, "hmm") - assert.ErrorIs(t, err, ssss.ErrNoPassphrase) - assert.Nil(t, key) -} - -func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { - km := getKeyMeta(key2MetaBrokenIV) - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) - assert.Nil(t, key) -} - -func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { - km := getKeyMeta(key2MetaBrokenMAC) - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) + assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) assert.Nil(t, key) } diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index b7465d3e..60852c55 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,8 +26,6 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") - ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata") - ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm. @@ -58,7 +56,6 @@ type EncryptedKeyData struct { type EncryptedAccountDataEventContent struct { Encrypted map[string]EncryptedKeyData `json:"encrypted"` - Metadata map[string]any `json:"com.beeper.metadata,omitzero"` } func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) { diff --git a/crypto/store.go b/crypto/store.go index 7620cf35..a84d4f13 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -9,13 +9,10 @@ package crypto import ( "context" "fmt" - "slices" "sort" "sync" - "time" "go.mau.fi/util/dbutil" - "go.mau.fi/util/exsync" "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" @@ -45,22 +42,11 @@ type Store interface { HasSession(context.Context, id.SenderKey) bool // GetSessions returns all Olm sessions in the store with the given sender key. GetSessions(context.Context, id.SenderKey) (OlmSessionList, error) - // GetLatestSession returns the most recent session that should be used for encrypting outbound messages. - // It's usually the one with the most recent successful decryption or the highest ID lexically. + // GetLatestSession returns the session with the highest session ID (lexiographically sorting). + // It's usually safe to return the most recently added session if sorting by session ID is too difficult. GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error) - // GetNewestSessionCreationTS returns the creation timestamp of the most recently created session for the given sender key. - GetNewestSessionCreationTS(context.Context, id.SenderKey) (time.Time, error) // UpdateSession updates a session that has previously been inserted with AddSession. UpdateSession(context.Context, id.SenderKey, *OlmSession) error - // DeleteSession deletes the given session that has been previously inserted with AddSession. - DeleteSession(context.Context, id.SenderKey, *OlmSession) error - - // PutOlmHash marks a given olm message hash as handled. - PutOlmHash(context.Context, [32]byte, time.Time) error - // GetOlmHash gets the time that a given olm hash was handled. - GetOlmHash(context.Context, [32]byte) (time.Time, error) - // DeleteOldOlmHashes deletes all olm hashes that were handled before the given time. - DeleteOldOlmHashes(context.Context, time.Time) error // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace @@ -146,8 +132,6 @@ type Store interface { IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error) - // GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer. - GetSignaturesForKeyBy(context.Context, id.UserID, id.Ed25519, id.UserID) (map[id.Ed25519]string, error) // PutSecret stores a named secret, replacing it if it exists already. PutSecret(context.Context, id.Secret, string) error @@ -187,7 +171,6 @@ type MemoryStore struct { KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string OutdatedUsers map[id.UserID]struct{} Secrets map[id.Secret]string - OlmHashes *exsync.Set[[32]byte] } var _ Store = (*MemoryStore)(nil) @@ -210,7 +193,6 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), OutdatedUsers: make(map[id.UserID]struct{}), Secrets: make(map[id.Secret]string), - OlmHashes: exsync.NewSet[[32]byte](), } } @@ -251,19 +233,6 @@ func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, ses return gs.save() } -func (gs *MemoryStore) DeleteSession(ctx context.Context, senderKey id.SenderKey, target *OlmSession) error { - gs.lock.Lock() - defer gs.lock.Unlock() - sessions, ok := gs.Sessions[senderKey] - if !ok { - return nil - } - gs.Sessions[senderKey] = slices.DeleteFunc(sessions, func(session *OlmSession) bool { - return session == target - }) - return gs.save() -} - func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error { // we don't need to do anything here because the session is a pointer and already stored in our map return gs.save() @@ -276,23 +245,6 @@ func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) boo return ok && len(sessions) > 0 && !sessions[0].Expired() } -func (gs *MemoryStore) PutOlmHash(_ context.Context, hash [32]byte, receivedAt time.Time) error { - gs.OlmHashes.Add(hash) - return nil -} - -func (gs *MemoryStore) GetOlmHash(_ context.Context, hash [32]byte) (time.Time, error) { - if gs.OlmHashes.Has(hash) { - // The time isn't that important, so we just return the current time - return time.Now(), nil - } - return time.Time{}, nil -} - -func (gs *MemoryStore) DeleteOldOlmHashes(_ context.Context, beforeTS time.Time) error { - return nil -} - func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() defer gs.lock.RUnlock() @@ -300,16 +252,7 @@ func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKe if !ok || len(sessions) == 0 { return nil, nil } - return sessions[len(sessions)-1], nil -} - -func (gs *MemoryStore) GetNewestSessionCreationTS(ctx context.Context, senderKey id.SenderKey) (createdAt time.Time, err error) { - var sess *OlmSession - sess, err = gs.GetLatestSession(ctx, senderKey) - if sess != nil { - createdAt = sess.CreationTime - } - return + return sessions[0], nil } func (gs *MemoryStore) getGroupSessions(roomID id.RoomID) map[id.SessionID]*InboundGroupSession { @@ -525,9 +468,6 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send } val, ok := gs.MessageIndices[key] if !ok { - if eventID == "" && timestamp == 0 { - return true, nil - } gs.MessageIndices[key] = messageIndexValue{ EventID: eventID, Timestamp: timestamp, diff --git a/crypto/store_test.go b/crypto/store_test.go index 7a47243e..740273dd 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -13,8 +13,6 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/crypto/olm" @@ -30,14 +28,22 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4 func getCryptoStores(t *testing.T) map[string]Store { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - require.NoError(t, err, "Error opening raw database") + if err != nil { + t.Fatalf("Error opening db: %v", err) + } db, err := dbutil.NewWithDB(rawDB, "sqlite3") - require.NoError(t, err, "Error creating database wrapper") + if err != nil { + t.Fatalf("Error opening db: %v", err) + } sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - err = sqlStore.DB.Upgrade(context.TODO()) - require.NoError(t, err, "Error upgrading database") + if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { + t.Fatalf("Error creating tables: %v", err) + } gobStore := NewMemoryStore(nil) + if err != nil { + t.Fatalf("Error creating Gob store: %v", err) + } return map[string]Store{ "sql": sqlStore, @@ -49,10 +55,9 @@ func TestPutNextBatch(t *testing.T) { stores := getCryptoStores(t) store := stores["sql"].(*SQLCryptoStore) store.PutNextBatch(context.Background(), "batch1") - - batch, err := store.GetNextBatch(context.Background()) - require.NoError(t, err, "Error retrieving next batch") - assert.Equal(t, "batch1", batch) + if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" { + t.Errorf("Expected batch1, got %v", batch) + } } func TestPutAccount(t *testing.T) { @@ -62,9 +67,15 @@ func TestPutAccount(t *testing.T) { acc := NewOlmAccount() store.PutAccount(context.TODO(), acc) retrieved, err := store.GetAccount(context.TODO()) - require.NoError(t, err, "Error retrieving account") - assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match") - assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match") + if err != nil { + t.Fatalf("Error retrieving account: %v", err) + } + if acc.IdentityKey() != retrieved.IdentityKey() { + t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey()) + } + if acc.SigningKey() != retrieved.SigningKey() { + t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey()) + } }) } } @@ -74,36 +85,18 @@ func TestValidateMessageIndex(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - - // Validating without event ID and timestamp before we have them should work - ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) - require.NoError(t, err, "Error validating message index") - assert.True(t, ok, "First message validation should be valid") - - // First message should validate successfully - ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) - require.NoError(t, err, "Error validating message index") - assert.True(t, ok, "First message validation should be valid") - - // Edit the timestamp and ensure validate fails - ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001) - require.NoError(t, err, "Error validating message index after timestamp change") - assert.False(t, ok, "First message validation should fail after timestamp change") - - // Edit the event ID and ensure validate fails - ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000) - require.NoError(t, err, "Error validating message index after event ID change") - assert.False(t, ok, "First message validation should fail after event ID change") - - // Validate again with the original parameters and ensure that it still passes - ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) - require.NoError(t, err, "Error validating message index") - assert.True(t, ok, "First message validation should be valid") - - // Validating without event ID and timestamp must fail if we already know them - ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) - require.NoError(t, err, "Error validating message index") - assert.False(t, ok, "First message validation should be invalid") + if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { + t.Error("First message not validated successfully") + } + if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok { + t.Error("First message validated successfully after changing timestamp") + } + if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok { + t.Error("First message validated successfully after changing event ID") + } + if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { + t.Error("First message not validated successfully for a second time") + } }) } } @@ -112,26 +105,37 @@ func TestStoreOlmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it") - + if store.HasSession(context.TODO(), olmSessID) { + t.Error("Found Olm session before inserting it") + } olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) - require.NoError(t, err, "Error creating internal Olm session") + if err != nil { + t.Fatalf("Error creating internal Olm session: %v", err) + } olmSess := OlmSession{ id: olmSessID, - Internal: olmInternal, + Internal: *olmInternal, } err = store.AddSession(context.TODO(), olmSessID, &olmSess) - require.NoError(t, err, "Error storing Olm session") - assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it") + if err != nil { + t.Errorf("Error storing Olm session: %v", err) + } + if !store.HasSession(context.TODO(), olmSessID) { + t.Error("Not found Olm session after inserting it") + } retrieved, err := store.GetLatestSession(context.TODO(), olmSessID) - require.NoError(t, err, "Error retrieving Olm session") - assert.EqualValues(t, olmSessID, retrieved.ID()) + if err != nil { + t.Errorf("Failed retrieving Olm session: %v", err) + } - pickled, err := retrieved.Internal.Pickle([]byte("test")) - require.NoError(t, err, "Error pickling Olm session") - assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original") + if retrieved.ID() != olmSessID { + t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID()) + } + if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != olmPickled { + t.Error("Pickled Olm session does not match original") + } }) } } @@ -143,24 +147,30 @@ func TestStoreMegolmSession(t *testing.T) { acc := NewOlmAccount() internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) - require.NoError(t, err, "Error creating internal inbound group session") + if err != nil { + t.Fatalf("Error creating internal inbound group session: %v", err) + } igs := &InboundGroupSession{ - Internal: internal, + Internal: *internal, SigningKey: acc.SigningKey(), SenderKey: acc.IdentityKey(), RoomID: "room1", } err = store.PutGroupSession(context.TODO(), igs) - require.NoError(t, err, "Error storing inbound group session") + if err != nil { + t.Errorf("Error storing inbound group session: %v", err) + } retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID()) - require.NoError(t, err, "Error retrieving inbound group session") + if err != nil { + t.Errorf("Error retrieving inbound group session: %v", err) + } - pickled, err := retrieved.Internal.Pickle([]byte("test")) - require.NoError(t, err, "Error pickling inbound group session") - assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original") + if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != groupSession { + t.Error("Pickled inbound group session does not match original") + } }) } } @@ -170,24 +180,39 @@ func TestStoreOutboundMegolmSession(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { sess, err := store.GetOutboundGroupSession(context.TODO(), "room1") - require.NoError(t, err, "Error retrieving outbound session") - require.Nil(t, sess, "Got outbound session before inserting") + if sess != nil { + t.Error("Got outbound session before inserting") + } + if err != nil { + t.Errorf("Error retrieving outbound session: %v", err) + } - outbound, err := NewOutboundGroupSession("room1", nil) - require.NoError(t, err) + outbound := NewOutboundGroupSession("room1", nil) err = store.AddOutboundGroupSession(context.TODO(), outbound) - require.NoError(t, err, "Error inserting outbound session") + if err != nil { + t.Errorf("Error inserting outbound session: %v", err) + } sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - require.NoError(t, err, "Error retrieving outbound session") - assert.NotNil(t, sess, "Did not get outbound session after inserting") + if sess == nil { + t.Error("Did not get outbound session after inserting") + } + if err != nil { + t.Errorf("Error retrieving outbound session: %v", err) + } err = store.RemoveOutboundGroupSession(context.TODO(), "room1") - require.NoError(t, err, "Error deleting outbound session") + if err != nil { + t.Errorf("Error deleting outbound session: %v", err) + } sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - require.NoError(t, err, "Error retrieving outbound session after deletion") - assert.Nil(t, sess, "Got outbound session after deleting") + if sess != nil { + t.Error("Got outbound session after deleting") + } + if err != nil { + t.Errorf("Error retrieving outbound session: %v", err) + } }) } } @@ -209,41 +234,58 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) { t.Run(storeName, func(t *testing.T) { device := resetDevice() err := store.PutDevice(context.TODO(), "user1", device) - require.NoError(t, err, "Error storing device") + if err != nil { + t.Errorf("Error storing devices: %v", err) + } shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - require.NoError(t, err, "Error checking if outbound group session is shared") - assert.False(t, shared, "Outbound group session should not be shared initially") + if err != nil { + t.Errorf("Error checking if outbound group session is shared: %v", err) + } else if shared { + t.Errorf("Outbound group session shared when it shouldn't") + } err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - require.NoError(t, err, "Error marking outbound group session as shared") + if err != nil { + t.Errorf("Error marking outbound group session as shared: %v", err) + } shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - require.NoError(t, err, "Error checking if outbound group session is shared") - assert.True(t, shared, "Outbound group session should be shared after marking it as such") + if err != nil { + t.Errorf("Error checking if outbound group session is shared: %v", err) + } else if !shared { + t.Errorf("Outbound group session not shared when it should") + } device = resetDevice() err = store.PutDevice(context.TODO(), "user1", device) - require.NoError(t, err, "Error storing device after resetting") + if err != nil { + t.Errorf("Error storing devices: %v", err) + } shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - require.NoError(t, err, "Error checking if outbound group session is shared") - assert.False(t, shared, "Outbound group session should not be shared after resetting device") + if err != nil { + t.Errorf("Error checking if outbound group session is shared: %v", err) + } else if shared { + t.Errorf("Outbound group session shared when it shouldn't") + } }) } } func TestStoreDevices(t *testing.T) { - devicesToCreate := 17 stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { outdated, err := store.GetOutdatedTrackedUsers(context.TODO()) - require.NoError(t, err, "Error filtering tracked users") - assert.Empty(t, outdated, "Expected no outdated tracked users initially") - + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) + } deviceMap := make(map[id.DeviceID]*id.Device) - for i := 0; i < devicesToCreate; i++ { + for i := 0; i < 17; i++ { iStr := strconv.Itoa(i) acc := NewOlmAccount() deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{ @@ -254,33 +296,59 @@ func TestStoreDevices(t *testing.T) { } } err = store.PutDevices(context.TODO(), "user1", deviceMap) - require.NoError(t, err, "Error storing devices") + if err != nil { + t.Errorf("Error storing devices: %v", err) + } devs, err := store.GetDevices(context.TODO(), "user1") - require.NoError(t, err, "Error getting devices") - assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate) - assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices") + if err != nil { + t.Errorf("Error getting devices: %v", err) + } + if len(devs) != 17 { + t.Errorf("Stored 17 devices, got back %v", len(devs)) + } + if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey { + t.Errorf("First device identity key does not match") + } + if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey { + t.Errorf("Last device identity key does not match") + } filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"}) - require.NoError(t, err, "Error filtering tracked users") - assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter") + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } else if len(filtered) != 1 || filtered[0] != "user1" { + t.Errorf("Expected to get 'user1' from filter, got %v", filtered) + } outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - require.NoError(t, err, "Error filtering tracked users") - assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage") - + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) + } err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"}) - require.NoError(t, err, "Error marking tracked users outdated") - + if err != nil { + t.Errorf("Error marking tracked users outdated: %v", err) + } outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - require.NoError(t, err, "Error filtering tracked users") - assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated") - + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) != 1 || outdated[0] != id.UserID("user1") { + t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated) + } err = store.PutDevices(context.TODO(), "user1", deviceMap) - require.NoError(t, err, "Error storing devices again") - + if err != nil { + t.Errorf("Error storing devices: %v", err) + } outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - require.NoError(t, err, "Error filtering tracked users") - assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices") + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got outdated tracked users %v when expected none", outdated) + } }) } } @@ -291,11 +359,16 @@ func TestStoreSecrets(t *testing.T) { t.Run(storeName, func(t *testing.T) { storedSecret := "trustno1" err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret) - require.NoError(t, err, "Error storing secret") + if err != nil { + t.Errorf("Error storing secret: %v", err) + } secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1) - require.NoError(t, err, "Error retrieving secret") - assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret") + if err != nil { + t.Errorf("Error storing secret: %v", err) + } else if secret != storedSecret { + t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret) + } }) } } diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go index b12fd9e2..c4f01a68 100644 --- a/crypto/utils/utils_test.go +++ b/crypto/utils/utils_test.go @@ -9,9 +9,6 @@ package utils import ( "encoding/base64" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestAES256Ctr(t *testing.T) { @@ -19,7 +16,9 @@ func TestAES256Ctr(t *testing.T) { key, iv := GenAttachmentA256CTR() enc := XorA256CTR([]byte(expected), key, iv) dec := XorA256CTR(enc, key, iv) - assert.EqualValues(t, expected, dec, "Decrypted text should match original") + if string(dec) != expected { + t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec)) + } var key2 [AESCTRKeyLength]byte var iv2 [AESCTRIVLength]byte @@ -30,7 +29,9 @@ func TestAES256Ctr(t *testing.T) { iv2[i] = byte(i) + 32 } dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2) - assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original") + if string(dec2) != expected { + t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2)) + } } func TestPBKDF(t *testing.T) { @@ -41,7 +42,9 @@ func TestPBKDF(t *testing.T) { key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256) expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E=" keyB64 := base64.StdEncoding.EncodeToString([]byte(key)) - assert.Equal(t, expected, keyB64) + if keyB64 != expected { + t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64) + } } func TestDecodeSSSSKey(t *testing.T) { @@ -50,10 +53,13 @@ func TestDecodeSSSSKey(t *testing.T) { expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw=" decodedB64 := base64.StdEncoding.EncodeToString(decoded[:]) - assert.Equal(t, expected, decodedB64) + if expected != decodedB64 { + t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64) + } - encoded := EncodeBase58RecoveryKey(decoded) - assert.Equal(t, recoveryKey, encoded) + if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey { + t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded) + } } func TestKeyDerivationAndHMAC(t *testing.T) { @@ -63,11 +69,15 @@ func TestKeyDerivationAndHMAC(t *testing.T) { aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master") ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=") - require.NoError(t, err) + if err != nil { + t.Error(err) + } calcMac := HMACSHA256B64(ciphertextBytes, hmacKey) expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E" - assert.Equal(t, expectedMac, calcMac) + if calcMac != expectedMac { + t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac) + } var ivBytes [AESCTRIVLength]byte decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==") @@ -75,5 +85,7 @@ func TestKeyDerivationAndHMAC(t *testing.T) { decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes)) expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s=" - assert.Equal(t, expectedDec, decrypted) + if expectedDec != decrypted { + t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted) + } } diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 3b943f28..7b1055d1 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -17,26 +17,20 @@ import ( type MockVerificationCallbacks interface { GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID GetScanQRCodeTransactions() []id.VerificationTransactionID - GetVerificationsReadyTransactions() []id.VerificationTransactionID GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode } type baseVerificationCallbacks struct { scanQRCodeTransactions []id.VerificationTransactionID verificationsRequested map[id.UserID][]id.VerificationTransactionID - verificationsReady []id.VerificationTransactionID qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode qrCodesScanned map[id.VerificationTransactionID]struct{} doneTransactions map[id.VerificationTransactionID]struct{} verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent emojisShown map[id.VerificationTransactionID][]rune - emojiDescriptionsShown map[id.VerificationTransactionID][]string decimalsShown map[id.VerificationTransactionID][]int } -var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil) -var _ MockVerificationCallbacks = (*baseVerificationCallbacks)(nil) - func newBaseVerificationCallbacks() *baseVerificationCallbacks { return &baseVerificationCallbacks{ verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, @@ -45,7 +39,6 @@ func newBaseVerificationCallbacks() *baseVerificationCallbacks { doneTransactions: map[id.VerificationTransactionID]struct{}{}, verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, emojisShown: map[id.VerificationTransactionID][]rune{}, - emojiDescriptionsShown: map[id.VerificationTransactionID][]string{}, decimalsShown: map[id.VerificationTransactionID][]int{}, } } @@ -58,10 +51,6 @@ func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.Verificatio return c.scanQRCodeTransactions } -func (c *baseVerificationCallbacks) GetVerificationsReadyTransactions() []id.VerificationTransactionID { - return c.verificationsReady -} - func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode { return c.qrCodesShown[txnID] } @@ -80,28 +69,18 @@ func (c *baseVerificationCallbacks) GetVerificationCancellation(txnID id.Verific return c.verificationCancellation[txnID] } -func (c *baseVerificationCallbacks) GetEmojisAndDescriptionsShown(txnID id.VerificationTransactionID) ([]rune, []string) { - return c.emojisShown[txnID], c.emojiDescriptionsShown[txnID] +func (c *baseVerificationCallbacks) GetEmojisShown(txnID id.VerificationTransactionID) []rune { + return c.emojisShown[txnID] } func (c *baseVerificationCallbacks) GetDecimalsShown(txnID id.VerificationTransactionID) []int { return c.decimalsShown[txnID] } -func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) { +func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) { c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) } -func (c *baseVerificationCallbacks) VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, allowScanQRCode bool, qrCode *verificationhelper.QRCode) { - c.verificationsReady = append(c.verificationsReady, txnID) - if allowScanQRCode { - c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) - } - if qrCode != nil { - c.qrCodesShown[txnID] = qrCode - } -} - func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) { c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{ Code: code, @@ -109,7 +88,7 @@ func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, t } } -func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) { +func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) { c.doneTransactions[txnID] = struct{}{} } @@ -117,8 +96,6 @@ type sasVerificationCallbacks struct { *baseVerificationCallbacks } -var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil) - func newSASVerificationCallbacks() *sasVerificationCallbacks { return &sasVerificationCallbacks{newBaseVerificationCallbacks()} } @@ -127,34 +104,39 @@ func newSASVerificationCallbacksWithBase(base *baseVerificationCallbacks) *sasVe return &sasVerificationCallbacks{base} } -func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) { +func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) { c.emojisShown[txnID] = emojis - c.emojiDescriptionsShown[txnID] = emojiDescriptions c.decimalsShown[txnID] = decimals } -type showQRCodeVerificationCallbacks struct { +type qrCodeVerificationCallbacks struct { *baseVerificationCallbacks } -var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil) - -func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks { - return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} +func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks { + return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()} } -func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks { - return &showQRCodeVerificationCallbacks{base} +func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks { + return &qrCodeVerificationCallbacks{base} } -func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { + c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) +} + +func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { + c.qrCodesShown[txnID] = qrCode +} + +func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { c.qrCodesScanned[txnID] = struct{}{} } type allVerificationCallbacks struct { *baseVerificationCallbacks *sasVerificationCallbacks - *showQRCodeVerificationCallbacks + *qrCodeVerificationCallbacks } func newAllVerificationCallbacks() *allVerificationCallbacks { @@ -162,6 +144,6 @@ func newAllVerificationCallbacks() *allVerificationCallbacks { return &allVerificationCallbacks{ base, newSASVerificationCallbacksWithBase(base), - newShowQRCodeVerificationCallbacksWithBase(base), + newQRCodeVerificationCallbacksWithBase(base), } } diff --git a/crypto/verificationhelper/ecdhkeys.go b/crypto/verificationhelper/ecdhkeys.go deleted file mode 100644 index 754530ed..00000000 --- a/crypto/verificationhelper/ecdhkeys.go +++ /dev/null @@ -1,57 +0,0 @@ -package verificationhelper - -import ( - "crypto/ecdh" - "encoding/json" -) - -type ECDHPrivateKey struct { - *ecdh.PrivateKey -} - -func (e *ECDHPrivateKey) UnmarshalJSON(data []byte) (err error) { - if len(data) == 0 { - return nil - } - var raw []byte - err = json.Unmarshal(data, &raw) - if err != nil { - return - } - if len(raw) == 0 { - return nil - } - e.PrivateKey, err = ecdh.X25519().NewPrivateKey(raw) - return err -} - -func (e ECDHPrivateKey) MarshalJSON() ([]byte, error) { - if e.PrivateKey == nil { - return json.Marshal(nil) - } - return json.Marshal(e.Bytes()) -} - -type ECDHPublicKey struct { - *ecdh.PublicKey -} - -func (e *ECDHPublicKey) UnmarshalJSON(data []byte) (err error) { - if len(data) == 0 { - return nil - } - var raw []byte - err = json.Unmarshal(data, &raw) - if err != nil { - return - } - if len(raw) == 0 { - return nil - } - e.PublicKey, err = ecdh.X25519().NewPublicKey(raw) - return -} - -func (e ECDHPublicKey) MarshalJSON() ([]byte, error) { - return json.Marshal(e.Bytes()) -} diff --git a/crypto/verificationhelper/ecdhkeys_test.go b/crypto/verificationhelper/ecdhkeys_test.go deleted file mode 100644 index 109fbf88..00000000 --- a/crypto/verificationhelper/ecdhkeys_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package verificationhelper_test - -import ( - "crypto/ecdh" - "crypto/rand" - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/crypto/verificationhelper" -) - -func TestECDHPrivateKey(t *testing.T) { - pk, err := ecdh.X25519().GenerateKey(rand.Reader) - require.NoError(t, err) - private := verificationhelper.ECDHPrivateKey{pk} - marshalled, err := json.Marshal(private) - require.NoError(t, err) - - assert.Len(t, marshalled, 46) - - var unmarshalled verificationhelper.ECDHPrivateKey - err = json.Unmarshal(marshalled, &unmarshalled) - require.NoError(t, err) - - assert.True(t, private.Equal(unmarshalled.PrivateKey)) -} - -func TestECDHPublicKey(t *testing.T) { - private, err := ecdh.X25519().GenerateKey(rand.Reader) - require.NoError(t, err) - - public := private.PublicKey() - - pub := verificationhelper.ECDHPublicKey{public} - marshalled, err := json.Marshal(pub) - require.NoError(t, err) - - assert.Len(t, marshalled, 46) - - var unmarshalled verificationhelper.ECDHPublicKey - err = json.Unmarshal(marshalled, &unmarshalled) - require.NoError(t, err) - - assert.True(t, public.Equal(unmarshalled.PublicKey)) -} diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go new file mode 100644 index 00000000..e35f51b2 --- /dev/null +++ b/crypto/verificationhelper/mockserver_test.go @@ -0,0 +1,255 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gorilla/mux" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/require" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow +// testing of the interactive verification process. +type mockServer struct { + *httptest.Server + + AccessTokenToUserID map[string]id.UserID + DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event + AccountData map[id.UserID]map[event.Type]json.RawMessage + DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys + MasterKeys map[id.UserID]mautrix.CrossSigningKeys + SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys + UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys +} + +func DecodeVarsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + var err error + for k, v := range vars { + vars[k], err = url.PathUnescape(v) + if err != nil { + panic(err) + } + } + next.ServeHTTP(w, r) + }) +} + +func createMockServer(t *testing.T) *mockServer { + t.Helper() + + server := mockServer{ + AccessTokenToUserID: map[string]id.UserID{}, + DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, + AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + } + + router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath() + router.Use(DecodeVarsMiddleware) + router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut) + router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut) + router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost) + + server.Server = httptest.NewServer(router) + return &server +} + +func (ms *mockServer) getUserID(r *http.Request) id.UserID { + authHeader := r.Header.Get("Authorization") + authHeader = strings.TrimPrefix(authHeader, "Bearer ") + userID, ok := ms.AccessTokenToUserID[authHeader] + if !ok { + panic("no user ID found for access token " + authHeader) + } + return userID +} + +func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("{}")) +} + +func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { + var loginReq mautrix.ReqLogin + json.NewDecoder(r.Body).Decode(&loginReq) + + deviceID := loginReq.DeviceID + if deviceID == "" { + deviceID = id.DeviceID(random.String(10)) + } + + accessToken := random.String(30) + userID := id.UserID(loginReq.Identifier.User) + s.AccessTokenToUserID[accessToken] = userID + + json.NewEncoder(w).Encode(&mautrix.RespLogin{ + AccessToken: accessToken, + DeviceID: deviceID, + UserID: userID, + }) +} + +func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + var req mautrix.ReqSendToDevice + json.NewDecoder(r.Body).Decode(&req) + evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType} + + for user, devices := range req.Messages { + for device, content := range devices { + if _, ok := s.DeviceInbox[user]; !ok { + s.DeviceInbox[user] = map[id.DeviceID][]event.Event{} + } + content.ParseRaw(evtType) + s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{ + Sender: s.getUserID(r), + Type: evtType, + Content: *content, + }) + } + } + s.emptyResp(w, r) +} + +func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + userID := id.UserID(vars["userID"]) + eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType} + + jsonData, _ := io.ReadAll(r.Body) + if _, ok := s.AccountData[userID]; !ok { + s.AccountData[userID] = map[event.Type]json.RawMessage{} + } + s.AccountData[userID][eventType] = json.RawMessage(jsonData) + s.emptyResp(w, r) +} + +func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqQueryKeys + json.NewDecoder(r.Body).Decode(&req) + resp := mautrix.RespQueryKeys{ + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + } + for user := range req.DeviceKeys { + resp.MasterKeys[user] = s.MasterKeys[user] + resp.UserSigningKeys[user] = s.UserSigningKeys[user] + resp.SelfSigningKeys[user] = s.SelfSigningKeys[user] + resp.DeviceKeys[user] = s.DeviceKeys[user] + } + json.NewEncoder(w).Encode(&resp) +} + +func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqUploadKeys + json.NewDecoder(r.Body).Decode(&req) + + userID := s.getUserID(r) + if _, ok := s.DeviceKeys[userID]; !ok { + s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} + } + s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys + + json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{ + OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50}, + }) +} + +func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.UploadCrossSigningKeysReq + json.NewDecoder(r.Body).Decode(&req) + + userID := s.getUserID(r) + s.MasterKeys[userID] = req.Master + s.SelfSigningKeys[userID] = req.SelfSigning + s.UserSigningKeys[userID] = req.UserSigning + + s.emptyResp(w, r) +} + +func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { + t.Helper() + client, err := mautrix.NewClient(ms.URL, "", "") + require.NoError(t, err) + client.StateStore = mautrix.NewMemoryStateStore() + + _, err = client.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: userID.String(), + }, + DeviceID: deviceID, + Password: "password", + StoreCredentials: true, + }) + require.NoError(t, err) + + cryptoStore := crypto.NewMemoryStore(nil) + cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore) + require.NoError(t, err) + client.Crypto = cryptoHelper + + err = cryptoHelper.Init(ctx) + require.NoError(t, err) + + machineLog := log.Logger.With(). + Stringer("my_user_id", userID). + Stringer("my_device_id", deviceID). + Logger() + cryptoHelper.Machine().Log = &machineLog + + err = cryptoHelper.Machine().ShareKeys(ctx, 50) + require.NoError(t, err) + + return client, cryptoStore +} + +func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { + t.Helper() + + for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { + client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) + ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] + } +} + +func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { + err := cryptoStore.PutDevice(ctx, userID, &id.Device{ + UserID: userID, + DeviceID: deviceID, + }) + if err != nil { + panic(err) + } +} diff --git a/crypto/verificationhelper/qrcode.go b/crypto/verificationhelper/qrcode.go index 11698152..a28d8fc3 100644 --- a/crypto/verificationhelper/qrcode.go +++ b/crypto/verificationhelper/qrcode.go @@ -82,10 +82,6 @@ func NewQRCodeFromBytes(data []byte) (*QRCode, error) { // // [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format func (q *QRCode) Bytes() []byte { - if q == nil { - return nil - } - var buf bytes.Buffer buf.WriteString("MATRIX") // Header buf.WriteByte(0x02) // Version diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index d8827b8b..2ea0a0ed 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -32,18 +32,16 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by Stringer("transaction_id", qrCode.TransactionID). Int("mode", int(qrCode.Mode)). Logger() - ctx = log.WithContext(ctx) - vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, qrCode.TransactionID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", qrCode.TransactionID, err) - } else if txn.VerificationState != VerificationStateReady { + txn, ok := vh.activeTransactions[qrCode.TransactionID] + if !ok { + return fmt.Errorf("unknown transaction ID found in QR code") + } else if txn.VerificationState != verificationStateReady { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "transaction found in the QR code is not in the ready state") } - txn.VerificationState = VerificationStateTheirQRScanned + txn.VerificationState = verificationStateTheirQRScanned // Verify the keys log.Info().Msg("Verifying keys from QR code") @@ -55,9 +53,9 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by switch qrCode.Mode { case QRCodeModeCrossSigning: - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) } if bytes.Equal(theirSigningKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the master key we expected") @@ -72,7 +70,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -80,7 +78,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // means that we don't trust the key. Key1 is the master key public // key, and Key2 is what the other device thinks our device key is. - if vh.client.UserID != txn.TheirUserID { + if vh.client.UserID != txn.TheirUser { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } @@ -116,12 +114,12 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeMasterKeyNotTrusted, "the master key is not trusted by this device, cannot verify device that does not trust the master key") } - if vh.client.UserID != txn.TheirUserID { + if vh.client.UserID != txn.TheirUser { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to get their device: %w", err) } @@ -142,7 +140,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to update device trust state after verifying: %+v", err) } @@ -179,124 +177,103 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by txn.SentOurDone = true if txn.ReceivedTheirDone { log.Debug().Msg("We already received their done event. Setting verification state to done.") - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - return err - } - vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) - } else { - return vh.store.SaveVerificationTransaction(ctx, txn) + delete(vh.activeTransactions, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID) } return nil } -// ConfirmQRCodeScanned confirms that our QR code has been scanned and sends -// the m.key.verification.done event to the other device for the given -// transaction ID. The transaction ID should be one received via the -// VerificationRequested callback in [RequiredCallbacks] or the -// [StartVerification] or [StartInRoomVerification] functions. +// ConfirmQRCodeScanned confirms that our QR code has been scanned and sends the +// m.key.verification.done event to the other device. func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "confirm QR code scanned"). Stringer("transaction_id", txnID). Logger() - ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateOurQRScanned { + txn, ok := vh.activeTransactions[txnID] + if !ok { + log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") + return nil + } else if txn.VerificationState != verificationStateOurQRScanned { return fmt.Errorf("transaction is not in the scanned state") } log.Info().Msg("Confirming QR code scanned") - // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) - if err != nil { - return err - } + if txn.TheirUser == vh.client.UserID { + // Self-signing situation. Trust their device. - // Trust their device - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) - if err != nil { - return fmt.Errorf("failed to update device trust state after verifying: %w", err) - } + // Get their device + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if err != nil { + return err + } - if txn.TheirUserID == vh.client.UserID { - // Self-signing situation. - // - // If we have the cross-signing keys, then we need to sign their device - // using the self-signing key. Otherwise, they have the master private - // key, so we need to trust the master public key. + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + + // Cross-sign their device with the self-signing key if vh.mach.CrossSigningKeys != nil { err = vh.mach.SignOwnDevice(ctx, theirDevice) if err != nil { - return fmt.Errorf("failed to sign our own new device: %w", err) - } - } else { - err = vh.mach.SignOwnMasterKey(ctx) - if err != nil { - return fmt.Errorf("failed to sign our own master key: %w", err) + return fmt.Errorf("failed to sign their device: %w", err) } } } else { // Cross-signing situation. Sign their master key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) } - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true if txn.ReceivedTheirDone { - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - return err - } - vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) - } else { - return vh.store.SaveVerificationTransaction(ctx, txn) + delete(vh.activeTransactions, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID) } return nil } -func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *VerificationTransaction) (*QRCode, error) { +func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *verificationTransaction) error { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). Logger() - ctx = log.WithContext(ctx) - - if !slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) || - !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) { - log.Info().Msg("Ignoring QR code generation request as reciprocating is not supported by both devices") - return nil, nil + if vh.showQRCode == nil { + log.Info().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") + return nil } else if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { log.Info().Msg("Ignoring QR code generation request as other device cannot scan QR codes") - return nil, nil + return nil } ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 { - return nil, errors.New("failed to get own cross-signing master public key") + return errors.New("failed to get own cross-signing master public key") } ownMasterKeyTrusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey) if err != nil { - return nil, err + return err } mode := QRCodeModeCrossSigning - if vh.client.UserID == txn.TheirUserID { + if vh.client.UserID == txn.TheirUser { // This is a self-signing situation. if ownMasterKeyTrusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted @@ -306,7 +283,7 @@ func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *Verificat } else { // This is a cross-signing situation. if !ownMasterKeyTrusted { - return nil, errors.New("cannot cross-sign other device when own master key is not trusted") + return errors.New("cannot cross-sign other device when own master key is not trusted") } mode = QRCodeModeCrossSigning } @@ -318,9 +295,9 @@ func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *Verificat key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other user's master signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { - return nil, err + return err } key2 = theirSigningKeys.MasterKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -328,9 +305,9 @@ func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *Verificat key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other device's key. - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { - return nil, err + return err } key2 = theirDevice.SigningKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyUntrusted: @@ -345,5 +322,6 @@ func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *Verificat qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) txn.QRCodeSharedSecret = qrCode.SharedSecret - return qrCode, nil + vh.showQRCode(ctx, txn.TransactionID, qrCode) + return nil } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index e6392c79..bf8c6050 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -13,7 +13,6 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" - "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -29,43 +28,41 @@ import ( "maunium.net/go/mautrix/id" ) -// StartSAS starts a SAS verification flow for the given transaction ID. The -// transaction ID should be one received via the VerificationRequested callback -// in [RequiredCallbacks] or the [StartVerification] or -// [StartInRoomVerification] functions. +// StartSAS starts a SAS verification flow. The transaction ID should be the +// transaction ID of a verification request that was received via the +// VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). - Str("verification_action", "start SAS"). + Str("verification_action", "accept verification"). Stringer("transaction_id", txnID). Logger() - ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateReady { - return fmt.Errorf("transaction is not in ready state: %s", txn.VerificationState.String()) + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateReady { + return errors.New("transaction is not in ready state") } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } - txn.VerificationState = VerificationStateSASStarted + txn.VerificationState = verificationStateSASStarted txn.StartedByUs = true if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") } // Ensure that we have their device key. - _, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + _, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { log.Err(err).Msg("Failed to fetch device") return err } log.Info().Msg("Sending start event") - startEventContent := event.VerificationStartEventContent{ + txn.StartEventContent = &event.VerificationStartEventContent{ FromDevice: vh.client.DeviceID, Method: event.VerificationMethodSAS, @@ -80,43 +77,35 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio event.SASMethodEmoji, }, } - if err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, &startEventContent); err != nil { - return err - } - txn.StartEventContent = &startEventContent - return vh.store.SaveVerificationTransaction(ctx, txn) + return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) } // ConfirmSAS indicates that the user has confirmed that the SAS matches SAS -// shown on the other user's device for the given transaction ID. The -// transaction ID should be one received via the VerificationRequested callback -// in [RequiredCallbacks] or the [StartVerification] or -// [StartInRoomVerification] functions. +// shown on the other user's device. func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "confirm SAS"). Stringer("transaction_id", txnID). Logger() - ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateSASKeysExchanged { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateSASKeysExchanged { return errors.New("transaction is not in keys exchanged state") } + var err error keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") - var masterKey string // My device key myDevice := vh.mach.OwnIdentity() myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) - keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, myDeviceKeyID.String(), myDevice.SigningKey.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.SigningKey.String()) if err != nil { return err } @@ -124,9 +113,8 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Master signing key crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { - masterKey = crossSigningKeys.MasterKey.String() - crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), masterKey) + crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) if err != nil { return err } @@ -137,7 +125,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keyIDs = append(keyIDs, keyID.String()) } slices.Sort(keyIDs) - keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { return err } @@ -150,23 +138,17 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat if err != nil { return err } - log.Info().Msg("Sent our MAC event") txn.SentOurMAC = true if txn.ReceivedTheirMAC { - txn.VerificationState = VerificationStateSASMACExchanged - - if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { - return fmt.Errorf("failed to trust keys: %w", err) - } - + txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true } - return vh.store.SaveVerificationTransaction(ctx, txn) + return nil } // onVerificationStartSAS handles the m.key.verification.start events with @@ -174,13 +156,12 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { +func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). - Str("verification_action", "start SAS"). + Str("verification_action", "start_sas"). Stringer("transaction_id", txn.TransactionID). Logger() - ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification start event") _, err := vh.mach.GetOrFetchDevice(ctx, evt.Sender, startEvt.FromDevice) @@ -223,30 +204,29 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve return fmt.Errorf("failed to generate ephemeral key: %w", err) } txn.MACMethod = macMethod - txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} + txn.EphemeralKey = ephemeralKey + txn.StartEventContent = startEvt - if !txn.StartedByUs { - commitment, err := calculateCommitment(ephemeralKey.PublicKey(), txn) - if err != nil { - return fmt.Errorf("failed to calculate commitment: %w", err) - } - - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ - Commitment: commitment, - Hash: hashAlgorithm, - KeyAgreementProtocol: keyAggreementProtocol, - MessageAuthenticationCode: macMethod, - ShortAuthenticationString: sasMethods, - }) - if err != nil { - return fmt.Errorf("failed to send accept event: %w", err) - } - txn.VerificationState = VerificationStateSASAccepted + commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) + if err != nil { + return fmt.Errorf("failed to calculate commitment: %w", err) } - return vh.store.SaveVerificationTransaction(ctx, txn) + + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ + Commitment: commitment, + Hash: hashAlgorithm, + KeyAgreementProtocol: keyAggreementProtocol, + MessageAuthenticationCode: macMethod, + ShortAuthenticationString: sasMethods, + }) + if err != nil { + return fmt.Errorf("failed to send accept event: %w", err) + } + txn.VerificationState = verificationStateSASAccepted + return nil } -func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, txn VerificationTransaction) ([]byte, error) { +func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { // The commitmentHashInput is the hash (encoded as unpadded base64) of the // concatenation of the device's ephemeral public key (encoded as // unpadded base64) and the canonical JSON representation of the @@ -256,7 +236,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, txn VerificationTransa // hashing it, but we are just stuck on that. commitmentHashInput := sha256.New() commitmentHashInput.Write([]byte(base64.RawStdEncoding.EncodeToString(ephemeralPubKey.Bytes()))) - encodedStartEvt, err := json.Marshal(txn.StartEventContent) + encodedStartEvt, err := json.Marshal(startEvt) if err != nil { return nil, err } @@ -268,7 +248,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, txn VerificationTransa // event. This follows Step 4 of [Section 11.12.2.2] of the Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { acceptEvt := evt.Content.AsVerificationAccept() log := vh.getLog(ctx).With(). Str("verification_action", "accept"). @@ -279,12 +259,11 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn Veri Str("message_authentication_code", string(acceptEvt.MessageAuthenticationCode)). Any("short_authentication_string", acceptEvt.ShortAuthenticationString). Logger() - ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification accept event") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateSASStarted { + if txn.VerificationState != verificationStateSASStarted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received accept event for a transaction that is not in the started state") return @@ -304,49 +283,49 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn Veri return } - txn.VerificationState = VerificationStateSASAccepted + txn.VerificationState = verificationStateSASAccepted txn.MACMethod = acceptEvt.MessageAuthenticationCode txn.Commitment = acceptEvt.Commitment - txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} + txn.EphemeralKey = ephemeralKey txn.EphemeralPublicKeyShared = true - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } } -func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "key"). Logger() - ctx = log.WithContext(ctx) keyEvt := evt.Content.AsVerificationKey() vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateSASAccepted { + if txn.VerificationState != verificationStateSASAccepted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received key event for a transaction that is not in the accepted state") return } var err error - publicKey, err := ecdh.X25519().NewPublicKey(keyEvt.Key) + txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) if err != nil { log.Err(err).Msg("Failed to generate other public key") return } - txn.OtherPublicKey = &ECDHPublicKey{publicKey} if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(publicKey, txn) + commitment, err := calculateCommitment(txn.OtherPublicKey, txn.StartEventContent) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return } if !bytes.Equal(commitment, txn.Commitment) { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "The key was not the one we expected") + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &event.VerificationCancelEventContent{ + Code: event.VerificationCancelCodeKeyMismatch, + Reason: "The key was not the one we expected.", + }) + if err != nil { + log.Err(err).Msg("Failed to send cancellation event") + } return } } else { @@ -359,7 +338,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific } txn.EphemeralPublicKeyShared = true } - txn.VerificationState = VerificationStateSASKeysExchanged + txn.VerificationState = verificationStateSASKeysExchanged sasBytes, err := vh.verificationSASHKDF(txn) if err != nil { @@ -369,7 +348,6 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific var decimals []int var emojis []rune - var emojiDescriptions []string if slices.Contains(txn.StartEventContent.ShortAuthenticationString, event.SASMethodDecimal) { decimals = []int{ (int(sasBytes[0])<<5 | int(sasBytes[1])>>3) + 1000, @@ -385,18 +363,13 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific // Right shift the number and then mask the lowest 6 bits. emojiIdx := (sasNum >> uint(48-(i+1)*6)) & 0b111111 emojis = append(emojis, allEmojis[emojiIdx]) - emojiDescriptions = append(emojiDescriptions, allEmojiDescriptions[emojiIdx]) } } - vh.showSAS(ctx, txn.TransactionID, emojis, emojiDescriptions, decimals) - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } + vh.showSAS(ctx, txn.TransactionID, emojis, decimals) } -func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) +func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) if err != nil { return nil, err } @@ -411,8 +384,8 @@ func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ( }, "|") theirInfo := strings.Join([]string{ - txn.TheirUserID.String(), - txn.TheirDeviceID.String(), + txn.TheirUser.String(), + txn.TheirDevice.String(), base64.RawStdEncoding.EncodeToString(txn.OtherPublicKey.Bytes()), }, "|") @@ -485,8 +458,8 @@ func BrokenB64Encode(input []byte) string { return string(output) } -func (vh *VerificationHelper) verificationMACHKDF(txn VerificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) +func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) if err != nil { return nil, err } @@ -586,78 +559,10 @@ var allEmojis = []rune{ '📌', } -var allEmojiDescriptions = []string{ - "Dog", - "Cat", - "Lion", - "Horse", - "Unicorn", - "Pig", - "Elephant", - "Rabbit", - "Panda", - "Rooster", - "Penguin", - "Turtle", - "Fish", - "Octopus", - "Butterfly", - "Flower", - "Tree", - "Cactus", - "Mushroom", - "Globe", - "Moon", - "Cloud", - "Fire", - "Banana", - "Apple", - "Strawberry", - "Corn", - "Pizza", - "Cake", - "Heart", - "Smiley", - "Robot", - "Hat", - "Glasses", - "Spanner", - "Santa", - "Thumbs Up", - "Umbrella", - "Hourglass", - "Clock", - "Gift", - "Light Bulb", - "Book", - "Pencil", - "Paperclip", - "Scissors", - "Lock", - "Key", - "Hammer", - "Telephone", - "Flag", - "Train", - "Bicycle", - "Aeroplane", - "Rocket", - "Trophy", - "Ball", - "Guitar", - "Trumpet", - "Bell", - "Anchor", - "Headphones", - "Folder", - "Pin", -} - -func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). Logger() - ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification MAC event") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -666,19 +571,16 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific // Verifying Keys MAC log.Info().Msg("Verifying MAC for all sent keys") var hasTheirDeviceKey bool - var masterKey string var keyIDs []string for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() - if kID == txn.TheirDeviceID.String() { + if kID == txn.TheirDevice.String() { hasTheirDeviceKey = true - } else { - masterKey = kID } } slices.Sort(keyIDs) - expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %w", err) return @@ -693,9 +595,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } // Verify the MAC for each key - var theirDevice *id.Device for keyID, mac := range macEvt.MAC { - log.Info().Stringer("key_id", keyID).Msg("Received MAC for key") + log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") alg, kID := keyID.Parse() if alg != id.KeyAlgorithmEd25519 { @@ -704,12 +605,9 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } var key string - if kID == txn.TheirDeviceID.String() { - if theirDevice != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInvalidMessage, "two keys found for their device ID") - return - } - theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + var theirDevice *id.Device + if kID == txn.TheirDevice.String() { + theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) return @@ -728,27 +626,31 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific key = crossSigningKeys.MasterKey.String() } - expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return } - if subtle.ConstantTimeCompare(expectedMAC, mac) == 0 { + if !bytes.Equal(expectedMAC, mac) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "MAC mismatch for key %s", keyID) return } + + // Trust their device + if kID == txn.TheirDevice.String() { + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) + return + } + } } log.Info().Msg("All MACs verified") txn.ReceivedTheirMAC = true if txn.SentOurMAC { - txn.VerificationState = VerificationStateSASMACExchanged - - if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to trust keys: %w", err) - return - } - + txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) @@ -756,57 +658,4 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } txn.SentOurDone = true } - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } -} - -func (vh *VerificationHelper) trustKeysAfterMACCheck(ctx context.Context, txn VerificationTransaction, masterKey string) error { - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) - if err != nil { - return fmt.Errorf("failed to fetch their device: %w", err) - } - // Trust their device - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) - if err != nil { - return fmt.Errorf("failed to update device trust state after verifying: %w", err) - } - - if txn.TheirUserID == vh.client.UserID { - // Self-signing situation. - // - // If we have the cross-signing keys, then we need to sign their device - // using the self-signing key. Otherwise, they have the master private - // key, so we need to trust the master public key. - if vh.mach.CrossSigningKeys != nil { - err = vh.mach.SignOwnDevice(ctx, theirDevice) - if err != nil { - return fmt.Errorf("failed to sign our own new device: %w", err) - } - } else { - err = vh.mach.SignOwnMasterKey(ctx) - if err != nil { - return fmt.Errorf("failed to sign our own master key: %w", err) - } - } - } else if masterKey != "" { - // Cross-signing situation. - // - // The master key was included in the list of keys to verify, so verify - // that it matches what we expect and sign their master key using the - // user-signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) - if err != nil { - return fmt.Errorf("couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) - } else if theirSigningKeys.MasterKey.String() != masterKey { - return fmt.Errorf("master keys do not match") - } - - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { - return fmt.Errorf("failed to sign %s's master key: %w", txn.TheirUserID, err) - } - } - return nil } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 0a781c16..e7ea53c5 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -9,13 +9,12 @@ package verificationhelper import ( "bytes" "context" - "errors" + "crypto/ecdh" "fmt" "sync" "time" "github.com/rs/zerolog" - "go.mau.fi/util/exslices" "go.mau.fi/util/jsontime" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -26,33 +25,117 @@ import ( "maunium.net/go/mautrix/id" ) +type verificationState int + +const ( + verificationStateRequested verificationState = iota + verificationStateReady + + verificationStateTheirQRScanned // We scanned their QR code + verificationStateOurQRScanned // They scanned our QR code + + verificationStateSASStarted // An SAS verification has been started + verificationStateSASAccepted // An SAS verification has been accepted + verificationStateSASKeysExchanged // An SAS verification has exchanged keys + verificationStateSASMACExchanged // An SAS verification has exchanged MACs +) + +func (step verificationState) String() string { + switch step { + case verificationStateRequested: + return "requested" + case verificationStateReady: + return "ready" + case verificationStateTheirQRScanned: + return "their_qr_scanned" + case verificationStateOurQRScanned: + return "our_qr_scanned" + case verificationStateSASStarted: + return "sas_started" + case verificationStateSASAccepted: + return "sas_accepted" + case verificationStateSASKeysExchanged: + return "sas_keys_exchanged" + case verificationStateSASMACExchanged: + return "sas_mac" + default: + return fmt.Sprintf("verificationStep(%d)", step) + } +} + +type verificationTransaction struct { + // RoomID is the room ID if the verification is happening in a room or + // empty if it is a to-device verification. + RoomID id.RoomID + + // VerificationState is the current step of the verification flow. + VerificationState verificationState + // TransactionID is the ID of the verification transaction. + TransactionID id.VerificationTransactionID + + // TheirDevice is the device ID of the device that either made the initial + // request or accepted our request. + TheirDevice id.DeviceID + // TheirUser is the user ID of the other user. + TheirUser id.UserID + // TheirSupportedMethods is a list of verification methods that the other + // device supports. + TheirSupportedMethods []event.VerificationMethod + + // SentToDeviceIDs is a list of devices which the initial request was sent + // to. This is only used for to-device verification requests, and is meant + // to be used to send cancellation requests to all other devices when a + // verification request is accepted via a m.key.verification.ready event. + SentToDeviceIDs []id.DeviceID + + // QRCodeSharedSecret is the shared secret that was encoded in the QR code + // that we showed. + QRCodeSharedSecret []byte + + StartedByUs bool // Whether the verification was started by us + StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content + Commitment []byte // The commitment from the m.key.verification.accept event + MACMethod event.MACMethod // The method used to calculate the MAC + EphemeralKey *ecdh.PrivateKey // The ephemeral key + EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared + OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key + ReceivedTheirMAC bool // Whether we have received their MAC + SentOurMAC bool // Whether we have sent our MAC + ReceivedTheirDone bool // Whether we have received their done event + SentOurDone bool // Whether we have sent our done event +} + // RequiredCallbacks is an interface representing the callbacks required for // the [VerificationHelper]. type RequiredCallbacks interface { // VerificationRequested is called when a verification request is received // from another device. - VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) - - // VerificationReady is called when a verification request has been - // accepted by both parties. - VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) + VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) // VerificationDone is called when the verification is done. - VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) + VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) } type ShowSASCallbacks interface { // ShowSAS is a callback that is called when the SAS verification has // generated a short authentication string to show. It is guaranteed that - // either the emojis and emoji descriptions lists, or the decimals list, or - // both will be present. - ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) + // either the emojis list, or the decimals list, or both will be present. + ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) } type ShowQRCodeCallbacks interface { + // ScanQRCode is called when another device has sent a + // m.key.verification.ready event and indicated that they are capable of + // showing a QR code. + ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) + + // ShowQRCode is called when the verification has been accepted and a QR + // code should be shown to the user. + ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) + // QRCodeScanned is called when the other user has scanned the QR code and // sent the m.key.verification.start event. QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) @@ -62,80 +145,67 @@ type VerificationHelper struct { client *mautrix.Client mach *crypto.OlmMachine - store VerificationStore + activeTransactions map[id.VerificationTransactionID]*verificationTransaction activeTransactionsLock sync.Mutex // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod - verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) - verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) + verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) - verificationDone func(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) + verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) - // showSAS is a callback that will be called after the SAS verification - // dance is complete and we want the client to show the emojis/decimals - showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) - // qrCodeScanned is a callback that will be called when the other device - // scanned the QR code we are showing - qrCodeScanned func(ctx context.Context, txnID id.VerificationTransactionID) + showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) + + scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID) + showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) + qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID) } var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan, supportsSAS bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, callbacks any, supportsScan bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } - if store == nil { - store = NewInMemoryVerificationStore() - } - helper := VerificationHelper{ - client: client, - mach: mach, - store: store, + client: client, + mach: mach, + activeTransactions: map[id.VerificationTransactionID]*verificationTransaction{}, } if c, ok := callbacks.(RequiredCallbacks); !ok { panic("callbacks must implement RequiredCallbacks") } else { helper.verificationRequested = c.VerificationRequested - helper.verificationReady = c.VerificationReady helper.verificationCancelledCallback = c.VerificationCancelled helper.verificationDone = c.VerificationDone } - if supportsSAS { - if c, ok := callbacks.(ShowSASCallbacks); !ok { - panic("callbacks must implement showSAS if supportsSAS is true") - } else { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) - helper.showSAS = c.ShowSAS - } + supportedMethods := map[event.VerificationMethod]struct{}{} + if c, ok := callbacks.(ShowSASCallbacks); ok { + supportedMethods[event.VerificationMethodSAS] = struct{}{} + helper.showSAS = c.ShowSAS } - if supportsQRShow { - if c, ok := callbacks.(ShowQRCodeCallbacks); !ok { - panic("callbacks must implement ShowQRCodeCallbacks if supportsQRShow is true") - } else { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) - helper.qrCodeScanned = c.QRCodeScanned - } + if c, ok := callbacks.(ShowQRCodeCallbacks); ok { + supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + helper.scanQRCode = c.ScanQRCode + helper.showQRCode = c.ShowQRCode + helper.qrCodeScaned = c.QRCodeScanned } - if supportsQRScan { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan) - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) + if supportsScan { + supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} } - helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods) + + helper.supportedMethods = maps.Keys(supportedMethods) return &helper } func (vh *VerificationHelper) getLog(ctx context.Context) *zerolog.Logger { logger := zerolog.Ctx(ctx).With(). Str("component", "verification"). - Stringer("device_id", vh.client.DeviceID). - Stringer("user_id", vh.client.UserID). Any("supported_methods", vh.supportedMethods). Logger() return &logger @@ -163,16 +233,14 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // Wrapper for the event handlers to check that the transaction ID is known // and ignore the event if it isn't. - wrapHandler := func(callback func(context.Context, VerificationTransaction, *event.Event)) func(context.Context, *event.Event) { + wrapHandler := func(callback func(context.Context, *verificationTransaction, *event.Event)) func(context.Context, *event.Event) { return func(ctx context.Context, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "check transaction ID"). Stringer("sender", evt.Sender). Stringer("room_id", evt.RoomID). Stringer("event_id", evt.ID). - Stringer("event_type", evt.Type). Logger() - ctx = log.WithContext(ctx) var transactionID id.VerificationTransactionID if evt.ID != "" { @@ -188,12 +256,8 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { log = log.With().Stringer("transaction_id", transactionID).Logger() vh.activeTransactionsLock.Lock() - txn, err := vh.store.GetVerificationTransaction(ctx, transactionID) - if err != nil && errors.Is(err, ErrUnknownVerificationTransaction) { - log.Err(err).Msg("failed to get verification transaction") - vh.activeTransactionsLock.Unlock() - return - } else if errors.Is(err, ErrUnknownVerificationTransaction) { + txn, ok := vh.activeTransactions[transactionID] + if !ok { // If it's a cancellation event for an unknown transaction, we // can just ignore it. if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { @@ -202,14 +266,11 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { return } - log.Warn().Msg("Sending cancellation event for unknown transaction ID") - // We have to create a fake transaction so that the call to - // cancelVerificationTxn works. - txn = VerificationTransaction{ - ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, - RoomID: evt.RoomID, - TheirUserID: evt.Sender, + // verificationCancelled works. + txn = &verificationTransaction{ + RoomID: evt.RoomID, + TheirUser: evt.Sender, } if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { txn.TransactionID = transactionable.GetTransactionID() @@ -217,7 +278,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { txn.TransactionID = id.VerificationTransactionID(evt.ID) } if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDeviceID = id.DeviceID(fromDevice.(string)) + txn.TheirDevice = id.DeviceID(fromDevice.(string)) } // Send a cancellation event. @@ -258,11 +319,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { syncer.OnEventType(event.InRoomVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS syncer.OnEventType(event.InRoomVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS - allTransactions, err := vh.store.GetAllVerificationTransactions(ctx) - for _, txn := range allTransactions { - vh.expireTransactionAt(txn.TransactionID, txn.ExpirationTime.Time) - } - return err + return nil } // StartVerification starts an interactive verification flow with the given @@ -287,14 +344,12 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI } } - log := vh.getLog(ctx).With(). + vh.getLog(ctx).Info(). Str("verification_action", "start verification"). Stringer("transaction_id", txnID). Stringer("to", to). Any("device_ids", maps.Keys(devices)). - Logger() - ctx = log.WithContext(ctx) - log.Info().Msg("Sending verification request") + Msg("Sending verification request") now := time.Now() content := &event.Content{ @@ -324,13 +379,13 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ - ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, - VerificationState: VerificationStateRequested, + vh.activeTransactions[txnID] = &verificationTransaction{ + VerificationState: verificationStateRequested, TransactionID: txnID, - TheirUserID: to, + TheirUser: to, SentToDeviceIDs: maps.Keys(devices), - }) + } + return txnID, nil } // StartInRoomVerification starts an interactive verification flow with the @@ -341,12 +396,11 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI Stringer("room_id", roomID). Stringer("to", to). Logger() - ctx = log.WithContext(ctx) log.Info().Msg("Sending verification request") content := event.MessageEventContent{ MsgType: event.MsgVerificationRequest, - Body: fmt.Sprintf("%s is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.", vh.client.UserID), + Body: "Alice is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.", FromDevice: vh.client.DeviceID, Methods: vh.supportedMethods, To: to, @@ -365,32 +419,28 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ - ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, + vh.activeTransactions[txnID] = &verificationTransaction{ RoomID: roomID, - VerificationState: VerificationStateRequested, + VerificationState: verificationStateRequested, TransactionID: txnID, - TheirUserID: to, - }) + TheirUser: to, + } + return txnID, nil } // AcceptVerification accepts a verification request. The transaction ID should // be the transaction ID of a verification request that was received via the // VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error { - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() - log := vh.getLog(ctx).With(). Str("verification_action", "accept verification"). Stringer("transaction_id", txnID). Logger() - ctx = log.WithContext(ctx) - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return err - } else if txn.VerificationState != VerificationStateRequested { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateRequested { return fmt.Errorf("transaction is not in the requested state") } @@ -414,54 +464,34 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } } - log.Info().Any("methods", maps.Keys(supportedMethods)).Msg("Sending ready event") + log.Info().Msg("Sending ready event") readyEvt := &event.VerificationReadyEventContent{ FromDevice: vh.client.DeviceID, Methods: maps.Keys(supportedMethods), } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) if err != nil { return err } - txn.VerificationState = VerificationStateReady + txn.VerificationState = verificationStateReady - supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) - supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) - supportsScanQRCode := supportsReciprocate && - slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) - - qrCode, err := vh.generateQRCode(ctx, &txn) - if err != nil { - return err + if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + vh.scanQRCode(ctx, txn.TransactionID) } - vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) - return vh.store.SaveVerificationTransaction(ctx, txn) + + return vh.generateAndShowQRCode(ctx, txn) } -// DismissVerification dismisses the verification request with the given -// transaction ID. The transaction ID should be one received via the -// VerificationRequested callback in [RequiredCallbacks] or the -// [StartVerification] or [StartInRoomVerification] functions. -func (vh *VerificationHelper) DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error { - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() - return vh.store.DeleteVerification(ctx, txnID) -} - -// DismissVerification cancels the verification request with the given -// transaction ID. The transaction ID should be one received via the -// VerificationRequested callback in [RequiredCallbacks] or the -// [StartVerification] or [StartInRoomVerification] functions. +// CancelVerification cancels a verification request. The transaction ID should +// be the transaction ID of a verification request that was received via the +// VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error { vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return err + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") } log := vh.getLog(ctx).With(). Str("verification_action", "cancel verification"). @@ -482,28 +512,29 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V } else { cancelEvt.SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUserID: {}, + txn.TheirUser: {}, }} - if len(txn.TheirDeviceID) > 0 { + if len(txn.TheirDevice) > 0 { // Send the cancellation event to only the device that accepted the // verification request. All of the other devices already received a // cancellation event with code "m.acceped". - req.Messages[txn.TheirUserID][txn.TheirDeviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUser][txn.TheirDevice] = &event.Content{Parsed: cancelEvt} } else { // Send the cancellation event to all of the devices that we sent the // request to. for _, deviceID := range txn.SentToDeviceIDs { if deviceID != vh.client.DeviceID { - req.Messages[txn.TheirUserID][deviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUser][deviceID] = &event.Content{Parsed: cancelEvt} } } } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { - return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUserID]), err) + return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUser]), err) } } - return vh.store.DeleteVerification(ctx, txn.TransactionID) + delete(vh.activeTransactions, txn.TransactionID) + return nil } // sendVerificationEvent sends a verification event to the other user's device @@ -515,7 +546,7 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V // [event.VerificationTransactionable]. // - evtType can be either the to-device or in-room version of the event type // as it is always stringified. -func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn VerificationTransaction, evtType event.Type, content any) error { +func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *verificationTransaction, evtType event.Type, content any) error { if txn.RoomID != "" { content.(event.Relatable).SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(txn.TransactionID)}) _, err := vh.client.SendMessageEvent(ctx, txn.RoomID, evtType, &event.Content{ @@ -527,13 +558,13 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn Ver } else { content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUserID: { - txn.TheirDeviceID: &event.Content{Parsed: content}, + txn.TheirUser: { + txn.TheirDevice: &event.Content{Parsed: content}, }, }} _, err := vh.client.SendToDevice(ctx, evtType, &req) if err != nil { - return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDeviceID, err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDevice, err) } } return nil @@ -545,24 +576,21 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn Ver // directly to expose the error to its caller). // // Must always be called with the activeTransactionsLock held. -func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn VerificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { +func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { + log := vh.getLog(ctx) reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() - log := vh.getLog(ctx).With(). + log.Info(). Stringer("transaction_id", txn.TransactionID). Str("code", string(code)). Str("reason", reason). - Logger() - ctx = log.WithContext(ctx) - log.Info().Msg("Sending cancellation event") + Msg("Sending cancellation event") cancelEvt := &event.VerificationCancelEventContent{Code: code, Reason: reason} err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) if err != nil { log.Err(err).Msg("failed to send cancellation event") return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("deleting verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } @@ -641,67 +669,62 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev } vh.activeTransactionsLock.Lock() - newTxn := VerificationTransaction{ - ExpirationTime: jsontime.UnixMilli{Time: verificationRequest.Timestamp.Add(time.Minute * 10)}, + newTxn := &verificationTransaction{ RoomID: evt.RoomID, - VerificationState: VerificationStateRequested, + VerificationState: verificationStateRequested, TransactionID: verificationRequest.TransactionID, - TheirDeviceID: verificationRequest.FromDevice, - TheirUserID: evt.Sender, + TheirDevice: verificationRequest.FromDevice, + TheirUser: evt.Sender, TheirSupportedMethods: verificationRequest.Methods, } - if txn, err := vh.store.FindVerificationTransactionForUserDevice(ctx, evt.Sender, verificationRequest.FromDevice); err != nil && !errors.Is(err, ErrUnknownVerificationTransaction) { - log.Err(err).Stringer("sender", evt.Sender).Stringer("device_id", verificationRequest.FromDevice).Msg("failed to find verification transaction") - vh.activeTransactionsLock.Unlock() - return - } else if !errors.Is(err, ErrUnknownVerificationTransaction) { - if txn.TransactionID == verificationRequest.TransactionID { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") - } else { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + for existingTxnID, existingTxn := range vh.activeTransactions { + if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return + } + + if existingTxnID == verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return } - vh.activeTransactionsLock.Unlock() - return - } - if err := vh.store.SaveVerificationTransaction(ctx, newTxn); err != nil { - log.Err(err).Msg("failed to save verification transaction") } + vh.activeTransactions[verificationRequest.TransactionID] = newTxn vh.activeTransactionsLock.Unlock() vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) - vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender, verificationRequest.FromDevice) + vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) } -func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expiresAt time.Time) { +func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expireAt time.Time) { go func() { - time.Sleep(time.Until(expiresAt)) + time.Sleep(time.Until(expireAt)) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(context.Background(), txnID) - if err == ErrUnknownVerificationTransaction { - // Already deleted, nothing to expire + txn, ok := vh.activeTransactions[txnID] + if !ok { return - } else if err != nil { - vh.getLog(context.Background()).Err(err).Msg("failed to get verification transaction to expire") - } else { - vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") } + + vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") }() } -func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). Logger() - ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateRequested { + if txn.VerificationState != verificationStateRequested { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") return } @@ -709,15 +732,10 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. - txn.VerificationState = VerificationStateReady - txn.TheirDeviceID = readyEvt.FromDevice + txn.VerificationState = verificationStateReady + txn.TheirDevice = readyEvt.FromDevice txn.TheirSupportedMethods = readyEvt.Methods - log.Info(). - Stringer("their_device_id", txn.TheirDeviceID). - Any("their_supported_methods", txn.TheirSupportedMethods). - Msg("Received verification ready event") - // If we sent this verification request, send cancellations to all of the // other devices. if len(txn.SentToDeviceIDs) > 0 { @@ -728,60 +746,50 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif Reason: "The verification was accepted on another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} - for _, deviceID := range txn.SentToDeviceIDs { - if deviceID == txn.TheirDeviceID || deviceID == vh.client.DeviceID { + devices, err := vh.mach.CryptoStore.GetDevices(ctx, txn.TheirUser) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to get devices for %s: %w", txn.TheirUser, err) + return + } + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + for deviceID := range devices { + if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted // the request or to our own device (which can happen if this // is a self-verification). continue } - req.Messages[txn.TheirUserID][deviceID] = content + req.Messages[txn.TheirUser][deviceID] = content } - _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { log.Warn().Err(err).Msg("Failed to send cancellation requests") } } - supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) - supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) - supportsScanQRCode := supportsReciprocate && - slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) - - qrCode, err := vh.generateQRCode(ctx, &txn) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate QR code: %w", err) - return + if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + vh.scanQRCode(ctx, txn.TransactionID) } - vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to save verification transaction: %w", err) + err := vh.generateAndShowQRCode(ctx, txn) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %w", err) } } -func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *verificationTransaction, evt *event.Event) { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "verification start"). Str("method", string(startEvt.Method)). - Stringer("their_device_id", txn.TheirDeviceID). - Any("their_supported_methods", txn.TheirSupportedMethods). - Bool("started_by_us", txn.StartedByUs). Logger() ctx = log.WithContext(ctx) - log.Info().Msg("Received verification start event") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState == VerificationStateSASStarted || txn.VerificationState == VerificationStateOurQRScanned || txn.VerificationState == VerificationStateTheirQRScanned { + if txn.VerificationState == verificationStateSASStarted || txn.VerificationState == verificationStateOurQRScanned || txn.VerificationState == verificationStateTheirQRScanned { // We might have sent the event, and they also sent an event. if txn.StartEventContent == nil || !txn.StartedByUs { // We didn't sent a start event yet, so we have gotten ourselves @@ -813,58 +821,49 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif return } - if txn.TheirUserID < vh.client.UserID || (txn.TheirUserID == vh.client.UserID && txn.TheirDeviceID < vh.client.DeviceID) { - log.Debug().Msg("Using their start event instead of ours because they are alphabetically before us") + if txn.TheirUser < vh.client.UserID || (txn.TheirUser == vh.client.UserID && txn.TheirDevice < vh.client.DeviceID) { + // Use their start event instead of ours txn.StartedByUs = false txn.StartEventContent = startEvt } - } else if txn.VerificationState != VerificationStateReady { + } else if txn.VerificationState != verificationStateReady { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return - } else { - txn.StartEventContent = startEvt } switch startEvt.Method { case event.VerificationMethodSAS: - log.Info().Msg("Received SAS start event") - txn.VerificationState = VerificationStateSASStarted + txn.VerificationState = verificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) } case event.VerificationMethodReciprocate: log.Info().Msg("Received reciprocate start event") - if !bytes.Equal(txn.QRCodeSharedSecret, txn.StartEventContent.Secret) { + if !bytes.Equal(txn.QRCodeSharedSecret, startEvt.Secret) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } - txn.VerificationState = VerificationStateOurQRScanned - vh.qrCodeScanned(ctx, txn.TransactionID) - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } + txn.VerificationState = verificationStateOurQRScanned + vh.qrCodeScaned(ctx, txn.TransactionID) default: // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes // should be of type m.reciprocate.v1. - log.Error().Str("method", string(txn.StartEventContent.Method)).Msg("Unsupported verification method in start event") - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, "unknown method %s", txn.StartEventContent.Method) + log.Error().Str("method", string(startEvt.Method)).Msg("Unsupported verification method in start event") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", startEvt.Method)) } } -func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn VerificationTransaction, evt *event.Event) { - log := vh.getLog(ctx).With(). +func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + vh.getLog(ctx).Info(). Str("verification_action", "done"). Stringer("transaction_id", txn.TransactionID). - Bool("sent_our_done", txn.SentOurDone). - Logger() - ctx = log.WithContext(ctx) - log.Info().Msg("Verification done") + Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if !slices.Contains([]VerificationState{ - VerificationStateTheirQRScanned, VerificationStateOurQRScanned, VerificationStateSASMACExchanged, + if !slices.Contains([]verificationState{ + verificationStateTheirQRScanned, verificationStateOurQRScanned, verificationStateSASMACExchanged, }, txn.VerificationState) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return @@ -872,61 +871,21 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn Verifi txn.ReceivedTheirDone = true if txn.SentOurDone { - if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("Delete verification failed") - } - vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) - } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") + delete(vh.activeTransactions, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID) } } -func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() - log := vh.getLog(ctx).With(). + vh.getLog(ctx).Info(). Str("verification_action", "cancel"). Stringer("transaction_id", txn.TransactionID). Str("cancel_code", string(cancelEvt.Code)). Str("reason", cancelEvt.Reason). - Logger() - ctx = log.WithContext(ctx) - log.Info().Msg("Verification was cancelled") + Msg("Verification was cancelled") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - - // Element (and at least the old desktop client) send cancellation events - // when the user rejects the verification request. This is really dumb, - // because they should just instead ignore the request and not send a - // cancellation. - // - // The above behavior causes a problem with the other devices that we sent - // the verification request to because they don't know that the request was - // cancelled. - // - // As a workaround, if we receive a cancellation event to a transaction - // that is currently in the REQUESTED state, then we will send - // cancellations to all of the devices that we sent the request to. This - // will ensure that all of the clients know that the request was cancelled. - if txn.VerificationState == VerificationStateRequested && len(txn.SentToDeviceIDs) > 0 { - content := &event.Content{ - Parsed: &event.VerificationCancelEventContent{ - ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, - Code: event.VerificationCancelCodeUser, - Reason: "The verification was rejected from another device.", - }, - } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} - for _, deviceID := range txn.SentToDeviceIDs { - req.Messages[txn.TheirUserID][deviceID] = content - } - _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) - if err != nil { - log.Warn().Err(err).Msg("Failed to send cancellation requests") - } - } - - if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("Delete verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go index 5e3f146b..2bbed25e 100644 --- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -11,7 +11,7 @@ import ( "fmt" "testing" - "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/rs/zerolog/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,6 +32,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -50,10 +51,10 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, bobUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -82,7 +83,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -97,7 +98,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -120,7 +121,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -135,7 +136,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index ea918cd4..443157b7 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -11,7 +11,7 @@ import ( "fmt" "testing" - "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/rs/zerolog/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,6 +36,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -61,7 +62,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) if tc.expectedAcceptError != "" { @@ -71,7 +72,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { require.NoError(t, err) } - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -134,6 +135,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -150,10 +152,10 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -182,7 +184,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -197,7 +199,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -220,7 +222,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -235,7 +237,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. @@ -249,6 +251,7 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -260,10 +263,10 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -275,12 +278,12 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // Emulate scanning the QR code shown by the receiving device // on the sending device. err = sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID") + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") // Emulate scanning the QR code shown by the sending device on // the receiving device. err = receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID") + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") } func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { @@ -307,6 +310,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -323,10 +327,10 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -344,7 +348,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the receiving device received a cancellation. receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) cancellation := receivingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) @@ -358,7 +362,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the sending device received a cancellation. sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] assert.Len(t, sendingInbox, 1) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) cancellation := sendingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 283eca84..e986cf85 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -5,7 +5,7 @@ import ( "fmt" "testing" - "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/rs/zerolog/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/maps" @@ -36,6 +36,7 @@ func TestVerification_SAS(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -59,10 +60,10 @@ func TestVerification_SAS(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Test that the start event is correct var startEvt *event.VerificationStartEventContent @@ -101,7 +102,7 @@ func TestVerification_SAS(t *testing.T) { if tc.sendingStartsSAS { // Process the verification start event on the receiving // device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Receiving device sent the accept event to the sending device sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] @@ -109,7 +110,7 @@ func TestVerification_SAS(t *testing.T) { acceptEvt = sendingInbox[0].Content.AsVerificationAccept() } else { // Process the verification start event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Sending device sent the accept event to the receiving device receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] @@ -128,7 +129,7 @@ func TestVerification_SAS(t *testing.T) { var firstKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the verification accept event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Sending device sends first key event to the receiving // device. @@ -138,7 +139,7 @@ func TestVerification_SAS(t *testing.T) { } else { // Process the verification accept event on the receiving // device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Receiving device sends first key event to the sending // device. @@ -154,7 +155,7 @@ func TestVerification_SAS(t *testing.T) { var secondKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the first key event on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Receiving device sends second key event to the sending // device. @@ -164,12 +165,10 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the receiving device showed emojis and SAS numbers. assert.Len(t, receivingCallbacks.GetDecimalsShown(txnID), 3) - emojis, descriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) - assert.Len(t, emojis, 7) - assert.Len(t, descriptions, 7) + assert.Len(t, receivingCallbacks.GetEmojisShown(txnID), 7) } else { // Process the first key event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Sending device sends second key event to the receiving // device. @@ -179,9 +178,7 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the sending device showed emojis and SAS numbers. assert.Len(t, sendingCallbacks.GetDecimalsShown(txnID), 3) - emojis, descriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) - assert.Len(t, emojis, 7) - assert.Len(t, descriptions, 7) + assert.Len(t, sendingCallbacks.GetEmojisShown(txnID), 7) } assert.Equal(t, txnID, secondKeyEvt.TransactionID) assert.NotEmpty(t, secondKeyEvt.Key) @@ -190,16 +187,13 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the SAS codes are the same. if tc.sendingStartsSAS { // Process the second key event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) } else { // Process the second key event on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) } assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID)) - sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) - receivingEmojis, receivingDescriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) - assert.Equal(t, sendingEmojis, receivingEmojis) - assert.Equal(t, sendingDescriptions, receivingDescriptions) + assert.Equal(t, sendingCallbacks.GetEmojisShown(txnID), receivingCallbacks.GetEmojisShown(txnID)) // Test that the first MAC event is correct var firstMACEvt *event.VerificationMACEventContent @@ -273,88 +267,12 @@ func TestVerification_SAS(t *testing.T) { // Test the transaction is done on both sides. We have to dispatch // twice to process and drain all of the events. - ts.DispatchToDevice(t, ctx, sendingClient) - ts.DispatchToDevice(t, ctx, receivingClient) - ts.DispatchToDevice(t, ctx, sendingClient) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, receivingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) }) } } - -func TestVerification_SAS_BothCallStart(t *testing.T) { - ctx := log.Logger.WithContext(context.TODO()) - - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) - var err error - - var sendingRecoveryKey string - var sendingCrossSigningKeysCache *crypto.CrossSigningKeysCache - - sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") - require.NoError(t, err) - assert.NotEmpty(t, sendingRecoveryKey) - assert.NotNil(t, sendingCrossSigningKeysCache) - - // Send the verification request from the sender device and accept - // it on the receiving device and receive the verification ready - // event on the sending device. - txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) - require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) - err = receivingHelper.AcceptVerification(ctx, txnID) - require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) - - err = sendingHelper.StartSAS(ctx, txnID) - require.NoError(t, err) - - err = receivingHelper.StartSAS(ctx, txnID) - require.NoError(t, err) - - // Ensure that both devices have received the verification start event. - receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] - assert.Len(t, receivingInbox, 1) - assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) - sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] - assert.Len(t, sendingInbox, 1) - assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID) - - // Process the start event from the receiving client to the sending client. - ts.DispatchToDevice(t, ctx, sendingClient) - receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID] - assert.Len(t, receivingInbox, 2) - assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) - assert.Equal(t, txnID, receivingInbox[1].Content.AsVerificationAccept().TransactionID) - - // Process the rest of the events until we need to confirm the SAS. - for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 { - ts.DispatchToDevice(t, ctx, receivingClient) - ts.DispatchToDevice(t, ctx, sendingClient) - } - - // Confirm the SAS only the receiving device. - receivingHelper.ConfirmSAS(ctx, txnID) - ts.DispatchToDevice(t, ctx, sendingClient) - - // Verification is not done until both devices confirm the SAS. - assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) - assert.False(t, receivingCallbacks.IsVerificationDone(txnID)) - - // Now, confirm it on the sending device. - sendingHelper.ConfirmSAS(ctx, txnID) - - // Dispatching the events to the receiving device should get us to the done - // state on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) - assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) - assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) - - // Dispatching the events to the sending client should get us to the done - // state on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) - assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) - assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) -} diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index ce5ec5b4..e8be5771 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -2,14 +2,13 @@ package verificationhelper_test import ( "context" - "database/sql" "fmt" "os" "testing" "time" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/rs/zerolog/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,7 +18,6 @@ import ( "maunium.net/go/mautrix/crypto/verificationhelper" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/mockserver" ) var aliceUserID = id.UserID("@alice:example.org") @@ -32,19 +30,9 @@ func init() { zerolog.DefaultContextLogger = &log.Logger } -func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { - err := cryptoStore.PutDevice(ctx, userID, &id.Device{ - UserID: userID, - DeviceID: deviceID, - }) - if err != nil { - panic(err) - } -} - -func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = mockserver.Create(t) + ts = createMockServer(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -58,9 +46,9 @@ func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserv return } -func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = mockserver.Create(t) + ts = createMockServer(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -77,20 +65,11 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserv func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { t.Helper() sendingCallbacks = newAllVerificationCallbacks() - senderVerificationDB, err := sql.Open("sqlite3", ":memory:") - require.NoError(t, err) - senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) - require.NoError(t, err) - - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() - receiverVerificationDB, err := sql.Open("sqlite3", ":memory:") - require.NoError(t, err) - receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) - require.NoError(t, err) - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -100,41 +79,32 @@ func TestVerification_Start(t *testing.T) { receivingDeviceID2 := id.DeviceID("receiving2") testCases := []struct { - supportsShow bool supportsScan bool - supportsSAS bool callbacks MockVerificationCallbacks startVerificationErrMsg string expectedVerificationMethods []event.VerificationMethod }{ - {false, false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {false, true, false, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - - {false, false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, - {true, false, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - - {false, false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {false, true, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - - {false, false, false, newAllVerificationCallbacks(), "no supported verification methods", nil}, - {false, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {false, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {true, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, } for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - ts := mockserver.Create(t) + ts := createMockServer(t) + defer ts.Close() client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan, tc.supportsSAS) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -144,7 +114,7 @@ func TestVerification_Start(t *testing.T) { return } - require.NoError(t, err) + assert.NoError(t, err) assert.NotEmpty(t, txnID) toDeviceInbox := ts.DeviceInbox[aliceUserID] @@ -158,7 +128,7 @@ func TestVerification_Start(t *testing.T) { assert.NotEmpty(t, toDeviceInbox[receivingDeviceID]) assert.NotEmpty(t, toDeviceInbox[receivingDeviceID2]) assert.Equal(t, toDeviceInbox[receivingDeviceID], toDeviceInbox[receivingDeviceID2]) - require.Len(t, toDeviceInbox[receivingDeviceID], 1) + assert.Len(t, toDeviceInbox[receivingDeviceID], 1) // Ensure that the verification request is correct. verificationRequest := toDeviceInbox[receivingDeviceID][0].Content.AsVerificationRequest() @@ -171,21 +141,13 @@ func TestVerification_Start(t *testing.T) { func TestVerification_StartThenCancel(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - bystanderDeviceID := id.DeviceID("bystander") for _, sendingCancels := range []bool{true, false} { t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { - ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) - bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) - bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true, true) - require.NoError(t, bystanderHelper.Init(ctx)) - - require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) - require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) - txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) @@ -195,15 +157,9 @@ func TestVerification_StartThenCancel(t *testing.T) { receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) - // Process the request event on the bystander device. - bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID] - assert.Len(t, bystanderInbox, 1) - assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID) - ts.DispatchToDevice(t, ctx, bystanderClient) - - // Cancel the verification request. + // Cancel the verification request on the sending device. var cancelEvt *event.VerificationCancelEventContent if sendingCancels { err = sendingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") @@ -215,11 +171,6 @@ func TestVerification_StartThenCancel(t *testing.T) { // Ensure that the cancellation event was sent to the receiving device. assert.Len(t, ts.DeviceInbox[aliceUserID][receivingDeviceID], 1) cancelEvt = ts.DeviceInbox[aliceUserID][receivingDeviceID][0].Content.AsVerificationCancel() - - // Ensure that the cancellation event was sent to the bystander device. - assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) - bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() - assert.Equal(t, cancelEvt, bystanderCancelEvt) } else { err = receivingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") assert.NoError(t, err) @@ -230,25 +181,10 @@ func TestVerification_StartThenCancel(t *testing.T) { // Ensure that the cancellation event was sent to the sending device. assert.Len(t, ts.DeviceInbox[aliceUserID][sendingDeviceID], 1) cancelEvt = ts.DeviceInbox[aliceUserID][sendingDeviceID][0].Content.AsVerificationCancel() - - // The bystander device should not have a cancellation event. - assert.Empty(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID]) } assert.Equal(t, txnID, cancelEvt.TransactionID) assert.Equal(t, event.VerificationCancelCodeUser, cancelEvt.Code) assert.Equal(t, "Recovery code preferred", cancelEvt.Reason) - - if !sendingCancels { - // Process the cancellation event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) - - // Ensure that the cancellation event was sent to the bystander device. - assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) - bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() - assert.Equal(t, txnID, bystanderCancelEvt.TransactionID) - assert.Equal(t, event.VerificationCancelCodeUser, bystanderCancelEvt.Code) - assert.Equal(t, "The verification was rejected from another device.", bystanderCancelEvt.Reason) - } }) } } @@ -256,7 +192,8 @@ func TestVerification_StartThenCancel(t *testing.T) { func TestVerification_Accept_NoSupportedMethods(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - ts := mockserver.Create(t) + ts := createMockServer(t) + defer ts.Close() sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID) @@ -269,12 +206,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true, true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -282,7 +219,7 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, txnID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiver ignored the request because it // doesn't support any of the verification methods in the @@ -295,44 +232,33 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { testCases := []struct { sendingSupportsScan bool - sendingSupportsShow bool receivingSupportsScan bool - receivingSupportsShow bool - sendingSupportsSAS bool - receivingSupportsSAS bool sendingCallbacks MockVerificationCallbacks receivingCallbacks MockVerificationCallbacks expectedVerificationMethods []event.VerificationMethod }{ - // TODO - {false, false, false, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, true, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - - {true, false, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {false, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, false, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {false, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, true, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {true, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, - - {true, true, true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, } for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") assert.NoError(t, err) assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan, tc.sendingSupportsSAS) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan, tc.receivingSupportsSAS) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -340,7 +266,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { require.NoError(t, err) // Process the verification request on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device received a verification // request with the correct transaction ID. @@ -350,13 +276,16 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - // Ensure that the receiving device get a notification about the - // transaction being ready. - assert.Contains(t, tc.receivingCallbacks.GetVerificationsReadyTransactions(), txnID) + _, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks) + _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) + sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks + _, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks) + _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) + receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks // Ensure that if the receiving device should show a QR code that // it has the correct content. - if tc.sendingSupportsScan && tc.receivingSupportsShow { + if tc.sendingSupportsScan && receivingCanShowQR { receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) assert.Equal(t, txnID, receivingShownQRCode.TransactionID) @@ -365,7 +294,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Check for whether the receiving device should be scanning a QR // code. - if tc.receivingSupportsScan && tc.sendingSupportsShow { + if tc.receivingSupportsScan && sendingCanShowQR { assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID) } @@ -380,15 +309,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Receive the m.key.verification.ready event on the sending // device. - ts.DispatchToDevice(t, ctx, sendingClient) - - // Ensure that the sending device got a notification about the - // transaction being ready. - assert.Contains(t, tc.sendingCallbacks.GetVerificationsReadyTransactions(), txnID) + ts.dispatchToDevice(t, ctx, sendingClient) // Ensure that if the sending device should show a QR code that it // has the correct content. - if tc.receivingSupportsScan && tc.sendingSupportsShow { + if tc.receivingSupportsScan && sendingCanShowQR { sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, sendingShownQRCode) assert.Equal(t, txnID, sendingShownQRCode.TransactionID) @@ -397,7 +322,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Check for whether the sending device should be scanning a QR // code. - if tc.sendingSupportsScan && tc.receivingSupportsShow { + if tc.sendingSupportsScan && receivingCanShowQR { assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID) } }) @@ -409,6 +334,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) nonParticipatingDeviceID1 := id.DeviceID("non-participating1") @@ -425,12 +351,12 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { // the receiving device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) // Receive the m.key.verification.ready event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // The sending and receiving devices should not have any cancellation // events in their inboxes. @@ -450,6 +376,7 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { func TestVerification_ErrorOnDoubleAccept(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -457,7 +384,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) err = receivingHelper.AcceptVerification(ctx, txnID) @@ -477,6 +404,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { func TestVerification_CancelOnDoubleStart(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -485,15 +413,15 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { // Send and accept the first verification request. txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID1) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event + ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event // Send a second verification request txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the sending device received a cancellation event for both of // the ongoing transactions. @@ -511,7 +439,7 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2)) - ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events + ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2)) } diff --git a/crypto/verificationhelper/verificationstore.go b/crypto/verificationhelper/verificationstore.go deleted file mode 100644 index 1eb8f752..00000000 --- a/crypto/verificationhelper/verificationstore.go +++ /dev/null @@ -1,159 +0,0 @@ -package verificationhelper - -import ( - "context" - "errors" - "fmt" - - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var ErrUnknownVerificationTransaction = errors.New("unknown transaction ID") - -type VerificationState int - -const ( - VerificationStateRequested VerificationState = iota - VerificationStateReady - - VerificationStateTheirQRScanned // We scanned their QR code - VerificationStateOurQRScanned // They scanned our QR code - - VerificationStateSASStarted // An SAS verification has been started - VerificationStateSASAccepted // An SAS verification has been accepted - VerificationStateSASKeysExchanged // An SAS verification has exchanged keys - VerificationStateSASMACExchanged // An SAS verification has exchanged MACs -) - -func (step VerificationState) String() string { - switch step { - case VerificationStateRequested: - return "requested" - case VerificationStateReady: - return "ready" - case VerificationStateTheirQRScanned: - return "their_qr_scanned" - case VerificationStateOurQRScanned: - return "our_qr_scanned" - case VerificationStateSASStarted: - return "sas_started" - case VerificationStateSASAccepted: - return "sas_accepted" - case VerificationStateSASKeysExchanged: - return "sas_keys_exchanged" - case VerificationStateSASMACExchanged: - return "sas_mac" - default: - return fmt.Sprintf("VerificationState(%d)", step) - } -} - -type VerificationTransaction struct { - ExpirationTime jsontime.UnixMilli `json:"expiration_time,omitempty"` - - // RoomID is the room ID if the verification is happening in a room or - // empty if it is a to-device verification. - RoomID id.RoomID `json:"room_id,omitempty"` - - // VerificationState is the current step of the verification flow. - VerificationState VerificationState `json:"verification_state"` - // TransactionID is the ID of the verification transaction. - TransactionID id.VerificationTransactionID `json:"transaction_id"` - - // TheirDeviceID is the device ID of the device that either made the - // initial request or accepted our request. - TheirDeviceID id.DeviceID `json:"their_device_id,omitempty"` - // TheirUserID is the user ID of the other user. - TheirUserID id.UserID `json:"their_user_id,omitempty"` - // TheirSupportedMethods is a list of verification methods that the other - // device supports. - TheirSupportedMethods []event.VerificationMethod `json:"their_supported_methods,omitempty"` - - // SentToDeviceIDs is a list of devices which the initial request was sent - // to. This is only used for to-device verification requests, and is meant - // to be used to send cancellation requests to all other devices when a - // verification request is accepted via a m.key.verification.ready event. - SentToDeviceIDs []id.DeviceID `json:"sent_to_device_ids,omitempty"` - - // QRCodeSharedSecret is the shared secret that was encoded in the QR code - // that we showed. - QRCodeSharedSecret []byte `json:"qr_code_shared_secret,omitempty"` - - StartedByUs bool `json:"started_by_us,omitempty"` // Whether the verification was started by us - StartEventContent *event.VerificationStartEventContent `json:"start_event_content,omitempty"` // The m.key.verification.start event content - Commitment []byte `json:"committment,omitempty"` // The commitment from the m.key.verification.accept event - MACMethod event.MACMethod `json:"mac_method,omitempty"` // The method used to calculate the MAC - EphemeralKey *ECDHPrivateKey `json:"ephemeral_key,omitempty"` // The ephemeral key - EphemeralPublicKeyShared bool `json:"ephemeral_public_key_shared,omitempty"` // Whether this device's ephemeral public key has been shared - OtherPublicKey *ECDHPublicKey `json:"other_public_key,omitempty"` // The other device's ephemeral public key - ReceivedTheirMAC bool `json:"received_their_mac,omitempty"` // Whether we have received their MAC - SentOurMAC bool `json:"sent_our_mac,omitempty"` // Whether we have sent our MAC - ReceivedTheirDone bool `json:"received_their_done,omitempty"` // Whether we have received their done event - SentOurDone bool `json:"sent_our_done,omitempty"` // Whether we have sent our done event -} - -type VerificationStore interface { - // DeleteVerification deletes a verification transaction by ID - DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error - // GetVerificationTransaction gets a verification transaction by ID - GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) - // SaveVerificationTransaction saves a verification transaction by ID - SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error - // FindVerificationTransactionForUserDevice finds a verification - // transaction by user and device ID - FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) - // GetAllVerificationTransactions returns all of the verification - // transactions. This is used to reset the cancellation timeouts. - GetAllVerificationTransactions(ctx context.Context) ([]VerificationTransaction, error) -} - -type InMemoryVerificationStore struct { - txns map[id.VerificationTransactionID]VerificationTransaction -} - -var _ VerificationStore = (*InMemoryVerificationStore)(nil) - -func NewInMemoryVerificationStore() *InMemoryVerificationStore { - return &InMemoryVerificationStore{ - txns: map[id.VerificationTransactionID]VerificationTransaction{}, - } -} - -func (i *InMemoryVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error { - if _, ok := i.txns[txnID]; !ok { - return ErrUnknownVerificationTransaction - } - delete(i.txns, txnID) - return nil -} - -func (i *InMemoryVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) { - if _, ok := i.txns[txnID]; !ok { - return VerificationTransaction{}, ErrUnknownVerificationTransaction - } - return i.txns[txnID], nil -} - -func (i *InMemoryVerificationStore) SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error { - i.txns[txn.TransactionID] = txn - return nil -} - -func (i *InMemoryVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) { - for _, existingTxn := range i.txns { - if existingTxn.TheirUserID == userID && existingTxn.TheirDeviceID == deviceID { - return existingTxn, nil - } - } - return VerificationTransaction{}, ErrUnknownVerificationTransaction -} - -func (i *InMemoryVerificationStore) GetAllVerificationTransactions(ctx context.Context) (txns []VerificationTransaction, err error) { - for _, txn := range i.txns { - txns = append(txns, txn) - } - return -} diff --git a/crypto/verificationhelper/verificationstore_test.go b/crypto/verificationhelper/verificationstore_test.go deleted file mode 100644 index e64153b1..00000000 --- a/crypto/verificationhelper/verificationstore_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package verificationhelper_test - -import ( - "context" - "database/sql" - "errors" - - _ "github.com/mattn/go-sqlite3" - "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/crypto/verificationhelper" - "maunium.net/go/mautrix/id" -) - -type SQLiteVerificationStore struct { - db *sql.DB -} - -const ( - selectVerifications = `SELECT transaction_data FROM verifications` - getVerificationByTransactionID = selectVerifications + ` WHERE transaction_id = ?1` - getVerificationByUserDeviceID = selectVerifications + ` - WHERE transaction_data->>'their_user_id' = ?1 - AND transaction_data->>'their_device_id' = ?2 - ` - deleteVerificationsQuery = `DELETE FROM verifications WHERE transaction_id = ?1` -) - -var _ verificationhelper.VerificationStore = (*SQLiteVerificationStore)(nil) - -func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerificationStore, error) { - _, err := db.ExecContext(ctx, ` - CREATE TABLE verifications ( - transaction_id TEXT PRIMARY KEY NOT NULL, - transaction_data JSONB NOT NULL - ); - CREATE INDEX verifications_user_device_id ON - verifications(transaction_data->>'their_user_id', transaction_data->>'their_device_id'); - `) - return &SQLiteVerificationStore{db}, err -} - -func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) { - rows, err := s.db.QueryContext(ctx, selectVerifications) - return dbutil.NewRowIterWithError(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { - err = rows.Scan(&dbutil.JSON{Data: &txn}) - return - }, err).AsList() -} - -func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) { - zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction") - row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID) - err = row.Scan(&dbutil.JSON{Data: &txn}) - if errors.Is(err, sql.ErrNoRows) { - err = verificationhelper.ErrUnknownVerificationTransaction - } - return -} - -func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) { - row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID) - err = row.Scan(&dbutil.JSON{Data: &txn}) - if errors.Is(err, sql.ErrNoRows) { - err = verificationhelper.ErrUnknownVerificationTransaction - } - return -} - -func (vq *SQLiteVerificationStore) SaveVerificationTransaction(ctx context.Context, txn verificationhelper.VerificationTransaction) (err error) { - zerolog.Ctx(ctx).Debug().Any("transaction", &txn).Msg("Saving verification transaction") - _, err = vq.db.ExecContext(ctx, ` - INSERT INTO verifications (transaction_id, transaction_data) - VALUES (?1, ?2) - ON CONFLICT (transaction_id) DO UPDATE - SET transaction_data=excluded.transaction_data - `, txn.TransactionID, &dbutil.JSON{Data: &txn}) - return -} - -func (vq *SQLiteVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) (err error) { - _, err = vq.db.ExecContext(ctx, deleteVerificationsQuery, txnID) - return -} diff --git a/error.go b/error.go index 4711b3dc..acd90892 100644 --- a/error.go +++ b/error.go @@ -12,8 +12,6 @@ import ( "fmt" "net/http" - "go.mau.fi/util/exhttp" - "go.mau.fi/util/exmaps" "golang.org/x/exp/maps" ) @@ -26,9 +24,6 @@ import ( // // logout // } var ( - // Generic error for when the server encounters an error and it does not have a more specific error code. - // Note that `errors.Is` will check the error message rather than code for M_UNKNOWNs. - MUnknown = RespError{ErrCode: "M_UNKNOWN", StatusCode: http.StatusInternalServerError} // Forbidden access, e.g. joining a room without permission, failed login. MForbidden = RespError{ErrCode: "M_FORBIDDEN", StatusCode: http.StatusForbidden} // Unrecognized request, e.g. the endpoint does not exist or is not implemented. @@ -67,28 +62,11 @@ var ( MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} // The client specified a parameter that has the wrong value. MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest} - // The client specified a room key backup version that is not the current room key backup version for the user. - MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden} MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"} MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} MConnectionTimeout = RespError{ErrCode: "M_CONNECTION_TIMEOUT"} MConnectionFailed = RespError{ErrCode: "M_CONNECTION_FAILED"} - - MUnredactedContentDeleted = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_DELETED"} - MUnredactedContentNotReceived = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_NOT_RECEIVED"} -) - -var ( - ErrClientIsNil = errors.New("client is nil") - ErrClientHasNoHomeserver = errors.New("client has no homeserver set") - - ErrResponseTooLong = errors.New("response content length too long") - ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body") - - // Special error that indicates we should retry canceled contexts. Note that on it's own this - // is useless, the context itself must also be replaced. - ErrContextCancelRetry = errors.New("retry canceled context") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. @@ -114,9 +92,10 @@ func (e HTTPError) Error() string { if e.WrappedError != nil { return fmt.Sprintf("%s: %v", e.Message, e.WrappedError) } else if e.RespError != nil { - return fmt.Sprintf("%s (HTTP %d): %s", e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) + return fmt.Sprintf("failed to %s %s: %s (HTTP %d): %s", e.Request.Method, e.Request.URL.Path, + e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) } else { - msg := fmt.Sprintf("HTTP %d", e.Response.StatusCode) + msg := fmt.Sprintf("failed to %s %s: HTTP %d", e.Request.Method, e.Request.URL.Path, e.Response.StatusCode) if len(e.ResponseBody) > 0 { msg = fmt.Sprintf("%s: %s", msg, e.ResponseBody) } @@ -140,10 +119,7 @@ type RespError struct { Err string ExtraData map[string]any - StatusCode int - ExtraHeader map[string]string - - CanRetry bool + StatusCode int } func (e *RespError) UnmarshalJSON(data []byte) error { @@ -153,70 +129,19 @@ func (e *RespError) UnmarshalJSON(data []byte) error { } e.ErrCode, _ = e.ExtraData["errcode"].(string) e.Err, _ = e.ExtraData["error"].(string) - e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool) return nil } func (e *RespError) MarshalJSON() ([]byte, error) { - data := exmaps.NonNilClone(e.ExtraData) + data := maps.Clone(e.ExtraData) + if data == nil { + data = make(map[string]any) + } data["errcode"] = e.ErrCode data["error"] = e.Err - if e.CanRetry { - data["com.beeper.can_retry"] = e.CanRetry - } return json.Marshal(data) } -func (e RespError) Write(w http.ResponseWriter) { - if w == nil { - return - } - statusCode := e.StatusCode - if statusCode == 0 { - statusCode = http.StatusInternalServerError - } - for key, value := range e.ExtraHeader { - w.Header().Set(key, value) - } - exhttp.WriteJSONResponse(w, statusCode, &e) -} - -func (e RespError) WithMessage(msg string, args ...any) RespError { - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - e.Err = msg - return e -} - -func (e RespError) WithStatus(status int) RespError { - e.StatusCode = status - return e -} - -func (e RespError) WithCanRetry(canRetry bool) RespError { - e.CanRetry = canRetry - return e -} - -func (e RespError) WithExtraData(extraData map[string]any) RespError { - e.ExtraData = exmaps.NonNilClone(e.ExtraData) - maps.Copy(e.ExtraData, extraData) - return e -} - -func (e RespError) WithExtraHeader(key, value string) RespError { - e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) - e.ExtraHeader[key] = value - return e -} - -func (e RespError) WithExtraHeaders(headers map[string]string) RespError { - e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) - maps.Copy(e.ExtraHeader, headers) - return e -} - // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err diff --git a/event/accountdata.go b/event/accountdata.go index 223919a1..30ca35a2 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -105,15 +105,3 @@ func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time { } return time.Time{} } - -func (bmec *BeeperMuteEventContent) GetMuteDuration() time.Duration { - ts := bmec.GetMutedUntilTime() - now := time.Now() - if ts.Before(now) { - return 0 - } else if ts == MutedForever { - return -1 - } else { - return ts.Sub(now) - } -} diff --git a/event/audio.go b/event/audio.go index 9eeb8edb..0fc0818b 100644 --- a/event/audio.go +++ b/event/audio.go @@ -1,21 +1,8 @@ package event -import ( - "encoding/json" -) - type MSC1767Audio struct { - Duration int `json:"duration"` - Waveform []int `json:"waveform"` -} - -type serializableMSC1767Audio MSC1767Audio - -func (ma *MSC1767Audio) MarshalJSON() ([]byte, error) { - if ma.Waveform == nil { - ma.Waveform = []int{} - } - return json.Marshal((*serializableMSC1767Audio)(ma)) + Duration int `json:"duration,omitempty"` + Waveform []int `json:"waveform,omitempty"` } type MSC3245Voice struct{} diff --git a/event/beeper.go b/event/beeper.go index a1a60b35..1394a6ce 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -7,15 +7,6 @@ package event import ( - "encoding/base32" - "encoding/binary" - "encoding/json" - "fmt" - "html" - "regexp" - "strconv" - "strings" - "maunium.net/go/mautrix/id" ) @@ -53,8 +44,6 @@ type BeeperMessageStatusEventContent struct { LastRetry id.EventID `json:"last_retry,omitempty"` - TargetTxnID string `json:"relates_to_txn_id,omitempty"` - MutateEventKey string `json:"mutate_event_key,omitempty"` // Indicates the set of users to whom the event was delivered. If nil, then @@ -64,18 +53,6 @@ type BeeperMessageStatusEventContent struct { DeliveredToUsers *[]id.UserID `json:"delivered_to_users,omitempty"` } -type BeeperRelatesTo struct { - EventID id.EventID `json:"event_id,omitempty"` - RoomID id.RoomID `json:"room_id,omitempty"` - Type RelationType `json:"rel_type,omitempty"` -} - -type BeeperTranscriptionEventContent struct { - Text []ExtensibleText `json:"m.text,omitempty"` - Model string `json:"com.beeper.transcription.model,omitempty"` - RelatesTo BeeperRelatesTo `json:"com.beeper.relates_to,omitempty"` -} - type BeeperRetryMetadata struct { OriginalEventID id.EventID `json:"original_event_id"` RetryCount int `json:"retry_count"` @@ -88,54 +65,18 @@ type BeeperRoomKeyAckEventContent struct { FirstMessageIndex int `json:"first_message_index"` } -type BeeperChatDeleteEventContent struct { - DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` - FromMessageRequest bool `json:"from_message_request,omitempty"` -} - -type BeeperAcceptMessageRequestEventContent struct { - // Whether this was triggered by a message rather than an explicit event - IsImplicit bool `json:"-"` -} - -type BeeperSendStateEventContent struct { - Type string `json:"type"` - StateKey string `json:"state_key"` - Content Content `json:"content"` -} - -type IntOrString int - -func (ios *IntOrString) UnmarshalJSON(data []byte) error { - if len(data) > 0 && data[0] == '"' { - var str string - err := json.Unmarshal(data, &str) - if err != nil { - return err - } - intVal, err := strconv.Atoi(str) - if err != nil { - return err - } - *ios = IntOrString(intVal) - return nil - } - return json.Unmarshal(data, (*int)(ios)) -} - type LinkPreview struct { CanonicalURL string `json:"og:url,omitempty"` Title string `json:"og:title,omitempty"` Type string `json:"og:type,omitempty"` Description string `json:"og:description,omitempty"` - SiteName string `json:"og:site_name,omitempty"` ImageURL id.ContentURIString `json:"og:image,omitempty"` - ImageSize IntOrString `json:"matrix:image:size,omitempty"` - ImageWidth IntOrString `json:"og:image:width,omitempty"` - ImageHeight IntOrString `json:"og:image:height,omitempty"` - ImageType string `json:"og:image:type,omitempty"` + ImageSize int `json:"matrix:image:size,omitempty"` + ImageWidth int `json:"og:image:width,omitempty"` + ImageHeight int `json:"og:image:height,omitempty"` + ImageType string `json:"og:image:type,omitempty"` } // BeeperLinkPreview contains the data for a bundled URL preview as specified in MSC4095 @@ -146,7 +87,6 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` - ImageBlurhash string `json:"matrix:image:blurhash,omitempty"` } type BeeperProfileExtra struct { @@ -157,175 +97,3 @@ type BeeperProfileExtra struct { IsBridgeBot bool `json:"com.beeper.bridge.is_bridge_bot,omitempty"` IsNetworkBot bool `json:"com.beeper.bridge.is_network_bot,omitempty"` } - -type BeeperPerMessageProfile struct { - ID string `json:"id"` - Displayname string `json:"displayname,omitempty"` - AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` - AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` - HasFallback bool `json:"has_fallback,omitempty"` -} - -type BeeperActionMessageType string - -const ( - BeeperActionMessageCall BeeperActionMessageType = "call" -) - -type BeeperActionMessageCallType string - -const ( - BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice" - BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video" -) - -type BeeperActionMessage struct { - Type BeeperActionMessageType `json:"type"` - CallType BeeperActionMessageCallType `json:"call_type,omitempty"` -} - -func (content *MessageEventContent) AddPerMessageProfileFallback() { - if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { - return - } - content.BeeperPerMessageProfile.HasFallback = true - content.EnsureHasHTML() - content.Body = fmt.Sprintf("%s: %s", content.BeeperPerMessageProfile.Displayname, content.Body) - content.FormattedBody = fmt.Sprintf( - "%s: %s", - html.EscapeString(content.BeeperPerMessageProfile.Displayname), - content.FormattedBody, - ) -} - -var HTMLProfileFallbackRegex = regexp.MustCompile(`([^<]+): `) - -func (content *MessageEventContent) RemovePerMessageProfileFallback() { - if content.NewContent != nil && content.NewContent != content { - content.NewContent.RemovePerMessageProfileFallback() - } - if content == nil || content.BeeperPerMessageProfile == nil || !content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { - return - } - content.BeeperPerMessageProfile.HasFallback = false - content.Body = strings.TrimPrefix(content.Body, content.BeeperPerMessageProfile.Displayname+": ") - if content.Format == FormatHTML { - content.FormattedBody = HTMLProfileFallbackRegex.ReplaceAllLiteralString(content.FormattedBody, "") - } -} - -type BeeperAIStreamEventContent struct { - TurnID string `json:"turn_id"` - Seq int `json:"seq"` - Part map[string]any `json:"part"` - TargetEvent id.EventID `json:"target_event,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` -} - -type BeeperEncodedOrder struct { - order int64 - suborder int16 -} - -func NewBeeperEncodedOrder(order int64, suborder int16) *BeeperEncodedOrder { - return &BeeperEncodedOrder{order: order, suborder: suborder} -} - -func BeeperEncodedOrderFromString(str string) (*BeeperEncodedOrder, error) { - order, suborder, err := decodeIntPair(str) - if err != nil { - return nil, err - } - return &BeeperEncodedOrder{order: order, suborder: suborder}, nil -} - -func (b *BeeperEncodedOrder) String() string { - if b == nil { - return "" - } - return encodeIntPair(b.order, b.suborder) -} - -func (b *BeeperEncodedOrder) OrderPair() (int64, int16) { - if b == nil { - return 0, 0 - } - return b.order, b.suborder -} - -func (b *BeeperEncodedOrder) IsZero() bool { - return b == nil || (b.order == 0 && b.suborder == 0) -} - -func (b *BeeperEncodedOrder) MarshalJSON() ([]byte, error) { - return []byte(`"` + b.String() + `"`), nil -} - -func (b *BeeperEncodedOrder) UnmarshalJSON(data []byte) error { - if b == nil { - return fmt.Errorf("BeeperEncodedOrder: receiver is nil") - } - str := string(data) - if len(str) < 2 { - return fmt.Errorf("invalid encoded order string: %s", str) - } - decoded, err := BeeperEncodedOrderFromString(str[1 : len(str)-1]) - if err != nil { - return err - } - b.order, b.suborder = decoded.order, decoded.suborder - return nil -} - -// encodeIntPair encodes an int64 and an int16 into a lexicographically sortable string -func encodeIntPair(a int64, b int16) string { - // Create a buffer to hold the binary representation of the integers. - // Will need 8 bytes for the int64 and 2 bytes for the int16. - var buf [10]byte - - // Flip the sign bit of each integer to map the entire int range to uint - // in a way that preserves the order of the original integers. - // - // Explanation: - // - By XORing with (1 << 63), we flip the most significant bit (sign bit) of the int64 value. - // - Negative numbers (which have a sign bit of 1) become smaller uint64 values. - // - Non-negative numbers (with a sign bit of 0) become larger uint64 values. - // - This mapping preserves the original ordering when the uint64 values are compared. - binary.BigEndian.PutUint64(buf[0:8], uint64(a)^(1<<63)) - binary.BigEndian.PutUint16(buf[8:10], uint16(b)^(1<<15)) - - // Encode the buffer into a Base32 string without padding using the Hex encoding. - // - // Explanation: - // - Base32 encoding converts binary data into a text representation using 32 ASCII characters. - // - Using Base32HexEncoding ensures that the characters are in lexicographical order. - // - Disabling padding results in a consistent string length, which is important for sorting. - encoded := base32.HexEncoding.WithPadding(base32.NoPadding).EncodeToString(buf[:]) - - return encoded -} - -// decodeIntPair decodes a string produced by encodeIntPair back into the original int64 and int16 values -func decodeIntPair(encoded string) (int64, int16, error) { - // Decode the Base32 string back into the original byte buffer. - buf, err := base32.HexEncoding.WithPadding(base32.NoPadding).DecodeString(encoded) - if err != nil { - return 0, 0, fmt.Errorf("failed to decode string: %w", err) - } - - // Check that the decoded buffer has the expected length. - if len(buf) != 10 { - return 0, 0, fmt.Errorf("invalid encoded string length: expected 10 bytes, got %d", len(buf)) - } - - // Read the uint values from the buffer using big-endian byte order. - aPos := binary.BigEndian.Uint64(buf[0:8]) - bPos := binary.BigEndian.Uint16(buf[8:10]) - - // Reverse the sign bit flip to retrieve the original values. - a := int64(aPos ^ (1 << 63)) - b := int16(bPos ^ (1 << 15)) - - return a, b, nil -} diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts deleted file mode 100644 index 26aeb347..00000000 --- a/event/capabilities.d.ts +++ /dev/null @@ -1,225 +0,0 @@ -/** - * The content of the `com.beeper.room_features` state event. - */ -export interface RoomFeatures { - /** - * Supported formatting features. If omitted, no formatting is supported. - * - * Capability level 0 means the corresponding HTML tags/attributes are ignored - * and will be treated as if they don't exist, which means that children will - * be rendered, but attributes will be dropped. - */ - formatting?: Record - /** - * Supported file message types and their features. - * - * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). - */ - file?: Record - /** - * Supported state event types and their parameters. Currently, there are no parameters, - * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.). - * - * Events that are not listed or have a support level of zero or below should be treated as unsupported. - * - * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here. - * `m.room.member` will not be listed here, as it's controlled by the member_actions field. - * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now. - */ - state?: Record - /** - * Supported member actions and their support levels. - * - * Actions that are not listed or have a support level of zero or below should be treated as unsupported. - */ - member_actions?: Record - - /** Maximum length of normal text messages. */ - max_text_length?: integer - - /** Whether location messages (`m.location`) are supported. */ - location_message?: CapabilitySupportLevel - /** Whether polls are supported. */ - poll?: CapabilitySupportLevel - /** Whether replying in a thread is supported. */ - thread?: CapabilitySupportLevel - /** Whether replying to a specific message is supported. */ - reply?: CapabilitySupportLevel - - /** Whether edits are supported. */ - edit?: CapabilitySupportLevel - /** How many times can an individual message be edited. */ - edit_max_count?: integer - /** How old messages can be edited, in seconds. */ - edit_max_age?: seconds - /** Whether deleting messages for everyone is supported */ - delete?: CapabilitySupportLevel - /** How old messages can be deleted for everyone, in seconds. */ - delete_max_age?: seconds - /** Whether deleting messages just for yourself is supported. No message age limit. */ - delete_for_me?: boolean - /** Allowed configuration options for disappearing timers. */ - disappearing_timer?: DisappearingTimerCapability - - /** Whether reactions are supported. */ - reaction?: CapabilitySupportLevel - /** How many reactions can be added to a single message. */ - reaction_count?: integer - /** - * The Unicode emojis allowed for reactions. If omitted, all emojis are allowed. - * Emojis in this list must include variation selector 16 if allowed in the Unicode spec. - */ - allowed_reactions?: string[] - /** Whether custom emoji reactions are allowed. */ - custom_emoji_reactions?: boolean - - /** Whether deleting the chat for yourself is supported. */ - delete_chat?: boolean - /** Whether deleting the chat for all participants is supported. */ - delete_chat_for_everyone?: boolean - /** What can be done with message requests? */ - message_request?: { - accept_with_message?: CapabilitySupportLevel - accept_with_button?: CapabilitySupportLevel - } -} - -declare type integer = number -declare type seconds = integer -declare type milliseconds = integer -declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application" -declare type MIMETypeOrPattern = - "*/*" - | `${MIMEClass}/*` - | `${MIMEClass}/${string}` - | `${MIMEClass}/${string}; ${string}` - -export enum MemberAction { - Ban = "ban", - Kick = "kick", - Leave = "leave", - RevokeInvite = "revoke_invite", - Invite = "invite", -} - -declare type EventType = string - -// This is an object for future extensibility (e.g. max name/topic length) -export interface StateFeatures { - level: CapabilitySupportLevel -} - -export enum CapabilityMsgType { - // Real message types used in the `msgtype` field - Image = "m.image", - File = "m.file", - Audio = "m.audio", - Video = "m.video", - - // Pseudo types only used in capabilities - /** An `m.audio` message that has `"org.matrix.msc3245.voice": {}` */ - Voice = "org.matrix.msc3245.voice", - /** An `m.video` message that has `"info": {"fi.mau.gif": true}`, or an `m.image` message of type `image/gif` */ - GIF = "fi.mau.gif", - /** An `m.sticker` event, no `msgtype` field */ - Sticker = "m.sticker", -} - -export interface FileFeatures { - /** - * The supported MIME types or type patterns and their support levels. - * - * If a mime type doesn't match any pattern provided, - * it should be treated as support level -2 (will be rejected). - */ - mime_types: Record - - /** The support level for captions within this file message type */ - caption?: CapabilitySupportLevel - /** The maximum length for captions (only applicable if captions are supported). */ - max_caption_length?: integer - /** The maximum file size as bytes. */ - max_size?: integer - /** For images and videos, the maximum width as pixels. */ - max_width?: integer - /** For images and videos, the maximum height as pixels. */ - max_height?: integer - /** For videos and audio files, the maximum duration as seconds. */ - max_duration?: seconds - - /** Can this type of file be sent as view-once media? */ - view_once?: boolean -} - -export enum DisappearingType { - None = "", - AfterRead = "after_read", - AfterSend = "after_send", -} - -export interface DisappearingTimerCapability { - types: DisappearingType[] - /** Allowed timer values. If omitted, any timer is allowed. */ - timers?: milliseconds[] - /** - * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear - * - * Generally, bridged rooms will want the object to be always present, while native Matrix rooms don't, - * so the hardcoded features for Matrix rooms should set this to true, while bridges will not. - */ - omit_empty_timer?: true -} - -/** - * The support level for a feature. These are integers rather than booleans - * to accurately represent what the bridge is doing and hopefully make the - * state event more generally useful. Our clients should check for > 0 to - * determine if the feature should be allowed. - */ -export enum CapabilitySupportLevel { - /** The feature is unsupported and messages using it will be rejected. */ - Rejected = -2, - /** The feature is unsupported and has no fallback. The message will go through, but data may be lost. */ - Dropped = -1, - /** The feature is unsupported, but may have a fallback. The nature of the fallback depends on the context. */ - Unsupported = 0, - /** The feature is partially supported (e.g. it may be converted to a different format). */ - PartialSupport = 1, - /** The feature is fully supported and can be safely used. */ - FullySupported = 2, -} - -/** - * A formatting feature that consists of specific HTML tags and/or attributes. - */ -export enum FormattingFeature { - Bold = "bold", // strong, b - Italic = "italic", // em, i - Underline = "underline", // u - Strikethrough = "strikethrough", // del, s - InlineCode = "inline_code", // code - CodeBlock = "code_block", // pre + code - SyntaxHighlighting = "code_block.syntax_highlighting", //

-	Blockquote = "blockquote", // blockquote
-	InlineLink = "inline_link", // a
-	UserLink = "user_link", // 
-	RoomLink = "room_link", // 
-	EventLink = "event_link", // 
-	AtRoomMention = "at_room_mention", // @room (no html tag)
-	UnorderedList = "unordered_list", // ul + li
-	OrderedList = "ordered_list", // ol + li
-	ListStart = "ordered_list.start", // 
    - ListJumpValue = "ordered_list.jump_value", //
  1. - CustomEmoji = "custom_emoji", // - Spoiler = "spoiler", // - SpoilerReason = "spoiler.reason", // - TextForegroundColor = "color.foreground", // - TextBackgroundColor = "color.background", // - HorizontalLine = "horizontal_line", // hr - Headers = "headers", // h1, h2, h3, h4, h5, h6 - Superscript = "superscript", // sup - Subscript = "subscript", // sub - Math = "math", // - DetailsSummary = "details_summary", //
    ......
    - Table = "table", // table, thead, tbody, tr, th, td -} diff --git a/event/capabilities.go b/event/capabilities.go deleted file mode 100644 index a86c726b..00000000 --- a/event/capabilities.go +++ /dev/null @@ -1,414 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package event - -import ( - "crypto/sha256" - "encoding/base64" - "encoding/binary" - "fmt" - "io" - "mime" - "slices" - "strings" - - "go.mau.fi/util/exerrors" - "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" - "golang.org/x/exp/constraints" - "golang.org/x/exp/maps" -) - -type RoomFeatures struct { - ID string `json:"id,omitempty"` - - // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. - - Formatting FormattingFeatureMap `json:"formatting,omitempty"` - File FileFeatureMap `json:"file,omitempty"` - State StateFeatureMap `json:"state,omitempty"` - MemberActions MemberFeatureMap `json:"member_actions,omitempty"` - - MaxTextLength int `json:"max_text_length,omitempty"` - - LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"` - Poll CapabilitySupportLevel `json:"poll,omitempty"` - Thread CapabilitySupportLevel `json:"thread,omitempty"` - Reply CapabilitySupportLevel `json:"reply,omitempty"` - - Edit CapabilitySupportLevel `json:"edit,omitempty"` - EditMaxCount int `json:"edit_max_count,omitempty"` - EditMaxAge *jsontime.Seconds `json:"edit_max_age,omitempty"` - Delete CapabilitySupportLevel `json:"delete,omitempty"` - DeleteForMe bool `json:"delete_for_me,omitempty"` - DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"` - - DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"` - - Reaction CapabilitySupportLevel `json:"reaction,omitempty"` - ReactionCount int `json:"reaction_count,omitempty"` - AllowedReactions []string `json:"allowed_reactions,omitempty"` - CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"` - - ReadReceipts bool `json:"read_receipts,omitempty"` - TypingNotifications bool `json:"typing_notifications,omitempty"` - Archive bool `json:"archive,omitempty"` - MarkAsUnread bool `json:"mark_as_unread,omitempty"` - DeleteChat bool `json:"delete_chat,omitempty"` - DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` - - MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"` - - PerMessageProfileRelay bool `json:"-"` -} - -func (rf *RoomFeatures) GetID() string { - if rf.ID != "" { - return rf.ID - } - return base64.RawURLEncoding.EncodeToString(rf.Hash()) -} - -func (rf *RoomFeatures) Clone() *RoomFeatures { - if rf == nil { - return nil - } - clone := *rf - clone.File = clone.File.Clone() - clone.Formatting = maps.Clone(clone.Formatting) - clone.State = clone.State.Clone() - clone.MemberActions = clone.MemberActions.Clone() - clone.EditMaxAge = ptr.Clone(clone.EditMaxAge) - clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) - clone.DisappearingTimer = clone.DisappearingTimer.Clone() - clone.AllowedReactions = slices.Clone(clone.AllowedReactions) - clone.MessageRequest = clone.MessageRequest.Clone() - return &clone -} - -type MemberFeatureMap map[MemberAction]CapabilitySupportLevel - -func (mfm MemberFeatureMap) Clone() MemberFeatureMap { - return maps.Clone(mfm) -} - -type MemberAction string - -const ( - MemberActionBan MemberAction = "ban" - MemberActionKick MemberAction = "kick" - MemberActionLeave MemberAction = "leave" - MemberActionRevokeInvite MemberAction = "revoke_invite" - MemberActionInvite MemberAction = "invite" -) - -type StateFeatureMap map[string]*StateFeatures - -func (sfm StateFeatureMap) Clone() StateFeatureMap { - dup := maps.Clone(sfm) - for key, value := range dup { - dup[key] = value.Clone() - } - return dup -} - -type StateFeatures struct { - Level CapabilitySupportLevel `json:"level"` -} - -func (sf *StateFeatures) Clone() *StateFeatures { - if sf == nil { - return nil - } - clone := *sf - return &clone -} - -func (sf *StateFeatures) Hash() []byte { - return sf.Level.Hash() -} - -type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel - -type FileFeatureMap map[CapabilityMsgType]*FileFeatures - -func (ffm FileFeatureMap) Clone() FileFeatureMap { - dup := maps.Clone(ffm) - for key, value := range dup { - dup[key] = value.Clone() - } - return dup -} - -type DisappearingTimerCapability struct { - Types []DisappearingType `json:"types"` - Timers []jsontime.Milliseconds `json:"timers,omitempty"` - - OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` -} - -func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability { - if dtc == nil { - return nil - } - clone := *dtc - clone.Types = slices.Clone(clone.Types) - clone.Timers = slices.Clone(clone.Timers) - return &clone -} - -func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool { - if dtc == nil || content == nil || content.Type == DisappearingTypeNone { - return true - } - return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) -} - -type MessageRequestFeatures struct { - AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"` - AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"` -} - -func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures { - return ptr.Clone(mrf) -} - -func (mrf *MessageRequestFeatures) Hash() []byte { - if mrf == nil { - return nil - } - hasher := sha256.New() - hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage) - hashValue(hasher, "accept_with_button", mrf.AcceptWithButton) - return hasher.Sum(nil) -} - -type CapabilityMsgType = MessageType - -// Message types which are used for event capability signaling, but aren't real values for the msgtype field. -const ( - CapMsgVoice CapabilityMsgType = "org.matrix.msc3245.voice" - CapMsgGIF CapabilityMsgType = "fi.mau.gif" - CapMsgSticker CapabilityMsgType = "m.sticker" -) - -type CapabilitySupportLevel int - -func (csl CapabilitySupportLevel) Partial() bool { - return csl >= CapLevelPartialSupport -} - -func (csl CapabilitySupportLevel) Full() bool { - return csl >= CapLevelFullySupported -} - -func (csl CapabilitySupportLevel) Reject() bool { - return csl <= CapLevelRejected -} - -const ( - CapLevelRejected CapabilitySupportLevel = -2 // The feature is unsupported and messages using it will be rejected. - CapLevelDropped CapabilitySupportLevel = -1 // The feature is unsupported and has no fallback. The message will go through, but data may be lost. - CapLevelUnsupported CapabilitySupportLevel = 0 // The feature is unsupported, but may have a fallback. - CapLevelPartialSupport CapabilitySupportLevel = 1 // The feature is partially supported (e.g. it may be converted to a different format). - CapLevelFullySupported CapabilitySupportLevel = 2 // The feature is fully supported and can be safely used. -) - -type FormattingFeature string - -const ( - FmtBold FormattingFeature = "bold" // strong, b - FmtItalic FormattingFeature = "italic" // em, i - FmtUnderline FormattingFeature = "underline" // u - FmtStrikethrough FormattingFeature = "strikethrough" // del, s - FmtInlineCode FormattingFeature = "inline_code" // code - FmtCodeBlock FormattingFeature = "code_block" // pre + code - FmtSyntaxHighlighting FormattingFeature = "code_block.syntax_highlighting" //
    
    -	FmtBlockquote          FormattingFeature = "blockquote"                     // blockquote
    -	FmtInlineLink          FormattingFeature = "inline_link"                    // a
    -	FmtUserLink            FormattingFeature = "user_link"                      // 
    -	FmtRoomLink            FormattingFeature = "room_link"                      // 
    -	FmtEventLink           FormattingFeature = "event_link"                     // 
    -	FmtAtRoomMention       FormattingFeature = "at_room_mention"                // @room (no html tag)
    -	FmtUnorderedList       FormattingFeature = "unordered_list"                 // ul + li
    -	FmtOrderedList         FormattingFeature = "ordered_list"                   // ol + li
    -	FmtListStart           FormattingFeature = "ordered_list.start"             // 
      - FmtListJumpValue FormattingFeature = "ordered_list.jump_value" //
    1. - FmtCustomEmoji FormattingFeature = "custom_emoji" // - FmtSpoiler FormattingFeature = "spoiler" // - FmtSpoilerReason FormattingFeature = "spoiler.reason" // - FmtTextForegroundColor FormattingFeature = "color.foreground" // - FmtTextBackgroundColor FormattingFeature = "color.background" // - FmtHorizontalLine FormattingFeature = "horizontal_line" // hr - FmtHeaders FormattingFeature = "headers" // h1, h2, h3, h4, h5, h6 - FmtSuperscript FormattingFeature = "superscript" // sup - FmtSubscript FormattingFeature = "subscript" // sub - FmtMath FormattingFeature = "math" // - FmtDetailsSummary FormattingFeature = "details_summary" //
      ......
      - FmtTable FormattingFeature = "table" // table, thead, tbody, tr, th, td -) - -type FileFeatures struct { - // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. - - MimeTypes map[string]CapabilitySupportLevel `json:"mime_types"` - - Caption CapabilitySupportLevel `json:"caption,omitempty"` - MaxCaptionLength int `json:"max_caption_length,omitempty"` - - MaxSize int64 `json:"max_size,omitempty"` - MaxWidth int `json:"max_width,omitempty"` - MaxHeight int `json:"max_height,omitempty"` - MaxDuration *jsontime.Seconds `json:"max_duration,omitempty"` - - ViewOnce bool `json:"view_once,omitempty"` -} - -func (ff *FileFeatures) GetMimeSupport(inputType string) CapabilitySupportLevel { - match, ok := ff.MimeTypes[inputType] - if ok { - return match - } - if strings.IndexByte(inputType, ';') != -1 { - plainMime, _, _ := mime.ParseMediaType(inputType) - if plainMime != "" { - if match, ok = ff.MimeTypes[plainMime]; ok { - return match - } - } - } - if slash := strings.IndexByte(inputType, '/'); slash > 0 { - generalType := fmt.Sprintf("%s/*", inputType[:slash]) - if match, ok = ff.MimeTypes[generalType]; ok { - return match - } - } - match, ok = ff.MimeTypes["*/*"] - if ok { - return match - } - return CapLevelRejected -} - -type hashable interface { - Hash() []byte -} - -func hashMap[Key ~string, Value hashable](w io.Writer, name string, data map[Key]Value) { - keys := maps.Keys(data) - slices.Sort(keys) - exerrors.Must(w.Write([]byte(name))) - for _, key := range keys { - exerrors.Must(w.Write([]byte(key))) - exerrors.Must(w.Write(data[key].Hash())) - exerrors.Must(w.Write([]byte{0})) - } -} - -func hashValue(w io.Writer, name string, data hashable) { - exerrors.Must(w.Write([]byte(name))) - exerrors.Must(w.Write(data.Hash())) -} - -func hashInt[T constraints.Integer](w io.Writer, name string, data T) { - exerrors.Must(w.Write(binary.BigEndian.AppendUint64([]byte(name), uint64(data)))) -} - -func hashBool[T ~bool](w io.Writer, name string, data T) { - exerrors.Must(w.Write([]byte(name))) - if data { - exerrors.Must(w.Write([]byte{1})) - } else { - exerrors.Must(w.Write([]byte{0})) - } -} - -func (csl CapabilitySupportLevel) Hash() []byte { - return []byte{byte(csl + 128)} -} - -func (rf *RoomFeatures) Hash() []byte { - hasher := sha256.New() - - hashMap(hasher, "formatting", rf.Formatting) - hashMap(hasher, "file", rf.File) - hashMap(hasher, "state", rf.State) - hashMap(hasher, "member_actions", rf.MemberActions) - - hashInt(hasher, "max_text_length", rf.MaxTextLength) - - hashValue(hasher, "location_message", rf.LocationMessage) - hashValue(hasher, "poll", rf.Poll) - hashValue(hasher, "thread", rf.Thread) - hashValue(hasher, "reply", rf.Reply) - - hashValue(hasher, "edit", rf.Edit) - hashInt(hasher, "edit_max_count", rf.EditMaxCount) - hashInt(hasher, "edit_max_age", rf.EditMaxAge.Get()) - - hashValue(hasher, "delete", rf.Delete) - hashBool(hasher, "delete_for_me", rf.DeleteForMe) - hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get()) - hashValue(hasher, "disappearing_timer", rf.DisappearingTimer) - - hashValue(hasher, "reaction", rf.Reaction) - hashInt(hasher, "reaction_count", rf.ReactionCount) - hasher.Write([]byte("allowed_reactions")) - for _, reaction := range rf.AllowedReactions { - hasher.Write([]byte(reaction)) - } - hashBool(hasher, "custom_emoji_reactions", rf.CustomEmojiReactions) - - hashBool(hasher, "read_receipts", rf.ReadReceipts) - hashBool(hasher, "typing_notifications", rf.TypingNotifications) - hashBool(hasher, "archive", rf.Archive) - hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) - hashBool(hasher, "delete_chat", rf.DeleteChat) - hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone) - hashValue(hasher, "message_request", rf.MessageRequest) - - return hasher.Sum(nil) -} - -func (dtc *DisappearingTimerCapability) Hash() []byte { - if dtc == nil { - return nil - } - hasher := sha256.New() - hasher.Write([]byte("types")) - for _, t := range dtc.Types { - hasher.Write([]byte(t)) - } - hasher.Write([]byte("timers")) - for _, timer := range dtc.Timers { - hashInt(hasher, "", timer.Milliseconds()) - } - return hasher.Sum(nil) -} - -func (ff *FileFeatures) Hash() []byte { - hasher := sha256.New() - hashMap(hasher, "mime_types", ff.MimeTypes) - hashValue(hasher, "caption", ff.Caption) - hashInt(hasher, "max_caption_length", ff.MaxCaptionLength) - hashInt(hasher, "max_size", ff.MaxSize) - hashInt(hasher, "max_width", ff.MaxWidth) - hashInt(hasher, "max_height", ff.MaxHeight) - hashInt(hasher, "max_duration", ff.MaxDuration.Get()) - hashBool(hasher, "view_once", ff.ViewOnce) - return hasher.Sum(nil) -} - -func (ff *FileFeatures) Clone() *FileFeatures { - if ff == nil { - return nil - } - clone := *ff - clone.MimeTypes = maps.Clone(clone.MimeTypes) - clone.MaxDuration = ptr.Clone(clone.MaxDuration) - return &clone -} diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go deleted file mode 100644 index ce07c4c0..00000000 --- a/event/cmdschema/content.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "crypto/sha256" - "encoding/base64" - "fmt" - "reflect" - "slices" - - "go.mau.fi/util/exsync" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type EventContent struct { - Command string `json:"command"` - Aliases []string `json:"aliases,omitempty"` - Parameters []*Parameter `json:"parameters,omitempty"` - Description *event.ExtensibleTextContainer `json:"description,omitempty"` - TailParam string `json:"fi.mau.tail_parameter,omitempty"` -} - -func (ec *EventContent) Validate() error { - if ec == nil { - return fmt.Errorf("event content is nil") - } else if ec.Command == "" { - return fmt.Errorf("command is empty") - } - var tailFound bool - dupMap := exsync.NewSet[string]() - for i, p := range ec.Parameters { - if err := p.Validate(); err != nil { - return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) - } else if !dupMap.Add(p.Key) { - return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1) - } else if p.Key == ec.TailParam { - tailFound = true - } else if tailFound && !p.Optional { - return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam) - } - } - if ec.TailParam != "" && !tailFound { - return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam) - } - return nil -} - -func (ec *EventContent) IsValid() bool { - return ec.Validate() == nil -} - -func (ec *EventContent) StateKey(owner id.UserID) string { - hash := sha256.Sum256([]byte(ec.Command + owner.String())) - return base64.StdEncoding.EncodeToString(hash[:]) -} - -func (ec *EventContent) Equals(other *EventContent) bool { - if ec == nil || other == nil { - return ec == other - } - return ec.Command == other.Command && - slices.Equal(ec.Aliases, other.Aliases) && - slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) && - ec.Description.Equals(other.Description) && - ec.TailParam == other.TailParam -} - -func init() { - event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{}) -} diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go deleted file mode 100644 index 4193b297..00000000 --- a/event/cmdschema/parameter.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "fmt" - "slices" - - "go.mau.fi/util/exslices" - - "maunium.net/go/mautrix/event" -) - -type Parameter struct { - Key string `json:"key"` - Schema *ParameterSchema `json:"schema"` - Optional bool `json:"optional,omitempty"` - Description *event.ExtensibleTextContainer `json:"description,omitempty"` - DefaultValue any `json:"fi.mau.default_value,omitempty"` -} - -func (p *Parameter) Equals(other *Parameter) bool { - if p == nil || other == nil { - return p == other - } - return p.Key == other.Key && - p.Schema.Equals(other.Schema) && - p.Optional == other.Optional && - p.Description.Equals(other.Description) && - p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values -} - -func (p *Parameter) Validate() error { - if p == nil { - return fmt.Errorf("parameter is nil") - } else if p.Key == "" { - return fmt.Errorf("key is empty") - } - return p.Schema.Validate() -} - -func (p *Parameter) IsValid() bool { - return p.Validate() == nil -} - -func (p *Parameter) GetDefaultValue() any { - if p != nil && p.DefaultValue != nil { - return p.DefaultValue - } else if p == nil || p.Optional { - return nil - } - return p.Schema.GetDefaultValue() -} - -type PrimitiveType string - -const ( - PrimitiveTypeString PrimitiveType = "string" - PrimitiveTypeInteger PrimitiveType = "integer" - PrimitiveTypeBoolean PrimitiveType = "boolean" - PrimitiveTypeServerName PrimitiveType = "server_name" - PrimitiveTypeUserID PrimitiveType = "user_id" - PrimitiveTypeRoomID PrimitiveType = "room_id" - PrimitiveTypeRoomAlias PrimitiveType = "room_alias" - PrimitiveTypeEventID PrimitiveType = "event_id" -) - -func (pt PrimitiveType) Schema() *ParameterSchema { - return &ParameterSchema{ - SchemaType: SchemaTypePrimitive, - Type: pt, - } -} - -func (pt PrimitiveType) IsValid() bool { - switch pt { - case PrimitiveTypeString, - PrimitiveTypeInteger, - PrimitiveTypeBoolean, - PrimitiveTypeServerName, - PrimitiveTypeUserID, - PrimitiveTypeRoomID, - PrimitiveTypeRoomAlias, - PrimitiveTypeEventID: - return true - default: - return false - } -} - -type SchemaType string - -const ( - SchemaTypePrimitive SchemaType = "primitive" - SchemaTypeArray SchemaType = "array" - SchemaTypeUnion SchemaType = "union" - SchemaTypeLiteral SchemaType = "literal" -) - -type ParameterSchema struct { - SchemaType SchemaType `json:"schema_type"` - Type PrimitiveType `json:"type,omitempty"` // Only for primitive - Items *ParameterSchema `json:"items,omitempty"` // Only for array - Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union - Value any `json:"value,omitempty"` // Only for literal -} - -func Literal(value any) *ParameterSchema { - return &ParameterSchema{ - SchemaType: SchemaTypeLiteral, - Value: value, - } -} - -func Enum(values ...any) *ParameterSchema { - return Union(exslices.CastFunc(values, Literal)...) -} - -func flattenUnion(variants []*ParameterSchema) []*ParameterSchema { - var flattened []*ParameterSchema - for _, variant := range variants { - switch variant.SchemaType { - case SchemaTypeArray: - panic(fmt.Errorf("illegal array schema in union")) - case SchemaTypeUnion: - flattened = append(flattened, flattenUnion(variant.Variants)...) - default: - flattened = append(flattened, variant) - } - } - return flattened -} - -func Union(variants ...*ParameterSchema) *ParameterSchema { - needsFlattening := false - for _, variant := range variants { - if variant.SchemaType == SchemaTypeArray { - panic(fmt.Errorf("illegal array schema in union")) - } else if variant.SchemaType == SchemaTypeUnion { - needsFlattening = true - } - } - if needsFlattening { - variants = flattenUnion(variants) - } - return &ParameterSchema{ - SchemaType: SchemaTypeUnion, - Variants: variants, - } -} - -func Array(items *ParameterSchema) *ParameterSchema { - if items.SchemaType == SchemaTypeArray { - panic(fmt.Errorf("illegal array schema in array")) - } - return &ParameterSchema{ - SchemaType: SchemaTypeArray, - Items: items, - } -} - -func (ps *ParameterSchema) GetDefaultValue() any { - if ps == nil { - return nil - } - switch ps.SchemaType { - case SchemaTypePrimitive: - switch ps.Type { - case PrimitiveTypeInteger: - return 0 - case PrimitiveTypeBoolean: - return false - default: - return "" - } - case SchemaTypeArray: - return []any{} - case SchemaTypeUnion: - if len(ps.Variants) > 0 { - return ps.Variants[0].GetDefaultValue() - } - return nil - case SchemaTypeLiteral: - return ps.Value - default: - return nil - } -} - -func (ps *ParameterSchema) IsValid() bool { - return ps.validate("") == nil -} - -func (ps *ParameterSchema) Validate() error { - return ps.validate("") -} - -func (ps *ParameterSchema) validate(parent SchemaType) error { - if ps == nil { - return fmt.Errorf("schema is nil") - } - switch ps.SchemaType { - case SchemaTypePrimitive: - if !ps.Type.IsValid() { - return fmt.Errorf("invalid primitive type %s", ps.Type) - } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil { - return fmt.Errorf("primitive schema has extra fields") - } - return nil - case SchemaTypeArray: - if parent != "" { - return fmt.Errorf("arrays can't be nested in other types") - } else if err := ps.Items.validate(ps.SchemaType); err != nil { - return fmt.Errorf("item schema is invalid: %w", err) - } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil { - return fmt.Errorf("array schema has extra fields") - } - return nil - case SchemaTypeUnion: - if len(ps.Variants) == 0 { - return fmt.Errorf("no variants specified for union") - } else if parent != "" && parent != SchemaTypeArray { - return fmt.Errorf("unions can't be nested in anything other than arrays") - } - for i, v := range ps.Variants { - if err := v.validate(ps.SchemaType); err != nil { - return fmt.Errorf("variant #%d is invalid: %w", i+1, err) - } - } - if ps.Type != "" || ps.Items != nil || ps.Value != nil { - return fmt.Errorf("union schema has extra fields") - } - return nil - case SchemaTypeLiteral: - switch typedVal := ps.Value.(type) { - case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue: - // ok - case map[string]any: - if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" { - return fmt.Errorf("literal value has invalid map data") - } - default: - return fmt.Errorf("literal value has unsupported type %T", ps.Value) - } - if ps.Type != "" || ps.Items != nil || ps.Variants != nil { - return fmt.Errorf("literal schema has extra fields") - } - return nil - default: - return fmt.Errorf("invalid schema type %s", ps.SchemaType) - } -} - -func (ps *ParameterSchema) Equals(other *ParameterSchema) bool { - if ps == nil || other == nil { - return ps == other - } - return ps.SchemaType == other.SchemaType && - ps.Type == other.Type && - ps.Items.Equals(other.Items) && - slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) && - ps.Value == other.Value // TODO this won't work for room/event ID values -} - -func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool { - switch ps.SchemaType { - case SchemaTypePrimitive: - return ps.Type == prim - case SchemaTypeUnion: - for _, variant := range ps.Variants { - if variant.AllowsPrimitive(prim) { - return true - } - } - return false - case SchemaTypeArray: - return ps.Items.AllowsPrimitive(prim) - default: - return false - } -} diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go deleted file mode 100644 index 92e69b60..00000000 --- a/event/cmdschema/parse.go +++ /dev/null @@ -1,478 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "errors" - "fmt" - "regexp" - "strconv" - "strings" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -const botArrayOpener = "<" -const botArrayCloser = ">" - -func parseQuoted(val string) (parsed, remaining string, quoted bool) { - if len(val) == 0 { - return - } - if !strings.HasPrefix(val, `"`) { - spaceIdx := strings.IndexByte(val, ' ') - if spaceIdx == -1 { - parsed = val - } else { - parsed = val[:spaceIdx] - remaining = strings.TrimLeft(val[spaceIdx+1:], " ") - } - return - } - val = val[1:] - var buf strings.Builder - for { - quoteIdx := strings.IndexByte(val, '"') - var valUntilQuote string - if quoteIdx == -1 { - valUntilQuote = val - } else { - valUntilQuote = val[:quoteIdx] - } - escapeIdx := strings.IndexByte(valUntilQuote, '\\') - if escapeIdx >= 0 { - buf.WriteString(val[:escapeIdx]) - if len(val) > escapeIdx+1 { - buf.WriteByte(val[escapeIdx+1]) - } - val = val[min(escapeIdx+2, len(val)):] - } else if quoteIdx >= 0 { - buf.WriteString(val[:quoteIdx]) - val = val[quoteIdx+1:] - break - } else if buf.Len() == 0 { - // Unterminated quote, no escape characters, val is the whole input - return val, "", true - } else { - // Unterminated quote, but there were escape characters previously - buf.WriteString(val) - val = "" - break - } - } - return buf.String(), strings.TrimLeft(val, " "), true -} - -// ParseInput tries to parse the given text into a bot command event matching this command definition. -// -// If the prefix doesn't match, this will return a nil content and nil error. -// If the prefix does match, some content is always returned, but there may still be an error if parsing failed. -func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) { - prefix := ec.parsePrefix(input, sigils, owner.String()) - if prefix == "" { - return nil, nil - } - content = &event.MessageEventContent{ - MsgType: event.MsgText, - Body: input, - Mentions: &event.Mentions{UserIDs: []id.UserID{owner}}, - MSC4391BotCommand: &event.MSC4391BotCommandInput{ - Command: ec.Command, - }, - } - content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):]) - return content, err -} - -func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { - args := make(map[string]any) - var retErr error - setError := func(err error) { - if err != nil && retErr == nil { - retErr = err - } - } - processParameter := func(param *Parameter, isLast, isTail, isNamed bool) { - origInput := input - var nextVal string - var wasQuoted bool - if param.Schema.SchemaType == SchemaTypeArray { - hasOpener := strings.HasPrefix(input, botArrayOpener) - arrayClosed := false - if hasOpener { - input = input[len(botArrayOpener):] - if strings.HasPrefix(input, botArrayCloser) { - input = strings.TrimLeft(input[len(botArrayCloser):], " ") - arrayClosed = true - } - } - var collector []any - for len(input) > 0 && !arrayClosed { - //origInput = input - nextVal, input, wasQuoted = parseQuoted(input) - if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) { - // The value wasn't quoted and has the array delimiter at the end, close the array - nextVal = strings.TrimRight(nextVal, botArrayCloser) - arrayClosed = true - } else if hasOpener && strings.HasPrefix(input, botArrayCloser) { - // The value was quoted or there was a space, and the next character is the - // array delimiter, close the array - input = strings.TrimLeft(input[len(botArrayCloser):], " ") - arrayClosed = true - } else if !hasOpener && !isLast { - // For array arguments in the middle without the <> delimiters, stop after the first item - arrayClosed = true - } - parsedVal, err := param.Schema.Items.ParseString(nextVal) - if err == nil { - collector = append(collector, parsedVal) - } else if hasOpener || isLast { - setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err)) - } else { - //input = origInput - } - } - args[param.Key] = collector - } else { - nextVal, input, wasQuoted = parseQuoted(input) - if (isLast || isTail) && !wasQuoted && len(input) > 0 { - // If the last argument is not quoted, just treat the rest of the string - // as the argument without escapes (arguments with escapes should be quoted). - nextVal += " " + input - input = "" - } - // Special case for named boolean parameters: if no value is given, treat it as true - if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) { - args[param.Key] = true - return - } - if nextVal == "" && !wasQuoted && !isNamed && !param.Optional { - setError(fmt.Errorf("missing value for required parameter %s", param.Key)) - } - parsedVal, err := param.Schema.ParseString(nextVal) - if err != nil { - args[param.Key] = param.GetDefaultValue() - // For optional parameters that fail to parse, restore the input and try passing it as the next parameter - if param.Optional && !isLast && !isNamed { - input = strings.TrimLeft(origInput, " ") - } else if !param.Optional || isNamed { - setError(fmt.Errorf("failed to parse %s: %w", param.Key, err)) - } - } else { - args[param.Key] = parsedVal - } - } - } - skipParams := make([]bool, len(ec.Parameters)) - for i, param := range ec.Parameters { - for strings.HasPrefix(input, "--") { - nameEndIdx := strings.IndexAny(input, " =") - if nameEndIdx == -1 { - nameEndIdx = len(input) - } - overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx]) - if overrideParam != nil { - // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input - input = strings.TrimPrefix(input[nameEndIdx:], "=") - skipParams[paramIdx] = true - processParameter(overrideParam, false, false, true) - } else { - break - } - } - isTail := param.Key == ec.TailParam - if skipParams[i] || (param.Optional && !isTail) { - continue - } - processParameter(param, i == len(ec.Parameters)-1, isTail, false) - } - jsonArgs, marshalErr := json.Marshal(args) - if marshalErr != nil { - return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr) - } - return jsonArgs, retErr -} - -func (ec *EventContent) parameterByName(name string) (*Parameter, int) { - for i, param := range ec.Parameters { - if strings.EqualFold(param.Key, name) { - return param, i - } - } - return nil, -1 -} - -func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) { - input := origInput - var chosenSigil string - for _, sigil := range sigils { - if strings.HasPrefix(input, sigil) { - chosenSigil = sigil - break - } - } - if chosenSigil == "" { - return "" - } - input = input[len(chosenSigil):] - var chosenAlias string - if !strings.HasPrefix(input, ec.Command) { - for _, alias := range ec.Aliases { - if strings.HasPrefix(input, alias) { - chosenAlias = alias - break - } - } - if chosenAlias == "" { - return "" - } - } else { - chosenAlias = ec.Command - } - input = strings.TrimPrefix(input[len(chosenAlias):], owner) - if input == "" || input[0] == ' ' { - input = strings.TrimLeft(input, " ") - return origInput[:len(origInput)-len(input)] - } - return "" -} - -func (pt PrimitiveType) ValidateValue(value any) bool { - _, err := pt.NormalizeValue(value) - return err == nil -} - -func normalizeNumber(value any) (int, error) { - switch typedValue := value.(type) { - case int: - return typedValue, nil - case int64: - return int(typedValue), nil - case float64: - return int(typedValue), nil - case json.Number: - if i, err := typedValue.Int64(); err != nil { - return 0, fmt.Errorf("failed to parse json.Number: %w", err) - } else { - return int(i), nil - } - default: - return 0, fmt.Errorf("unsupported type %T for integer", value) - } -} - -func (pt PrimitiveType) NormalizeValue(value any) (any, error) { - switch pt { - case PrimitiveTypeInteger: - return normalizeNumber(value) - case PrimitiveTypeBoolean: - bv, ok := value.(bool) - if !ok { - return nil, fmt.Errorf("unsupported type %T for boolean", value) - } - return bv, nil - case PrimitiveTypeString, PrimitiveTypeServerName: - str, ok := value.(string) - if !ok { - return nil, fmt.Errorf("unsupported type %T for string", value) - } - return str, pt.validateStringValue(str) - case PrimitiveTypeUserID, PrimitiveTypeRoomAlias: - str, ok := value.(string) - if !ok { - return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value) - } else if plainErr := pt.validateStringValue(str); plainErr == nil { - return str, nil - } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil { - return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err) - } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID { - return parsed.UserID(), nil - } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias { - return parsed.RoomAlias(), nil - } else { - return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1) - } - case PrimitiveTypeRoomID, PrimitiveTypeEventID: - riv, err := NormalizeRoomIDValue(value) - if err != nil { - return nil, err - } - return riv, riv.Validate() - default: - return nil, fmt.Errorf("cannot normalize value for argument type %s", pt) - } -} - -func (pt PrimitiveType) validateStringValue(value string) error { - switch pt { - case PrimitiveTypeString: - return nil - case PrimitiveTypeServerName: - if !id.ValidateServerName(value) { - return fmt.Errorf("invalid server name: %q", value) - } - return nil - case PrimitiveTypeUserID: - _, _, err := id.UserID(value).ParseAndValidateRelaxed() - return err - case PrimitiveTypeRoomAlias: - sigil, localpart, serverName := id.ParseCommonIdentifier(value) - if sigil != '#' || localpart == "" || serverName == "" { - return fmt.Errorf("invalid room alias: %q", value) - } else if !id.ValidateServerName(serverName) { - return fmt.Errorf("invalid server name in room alias: %q", serverName) - } - return nil - default: - panic(fmt.Errorf("validateStringValue called with invalid type %s", pt)) - } -} - -func parseBoolean(val string) (bool, error) { - if len(val) == 0 { - return false, fmt.Errorf("cannot parse empty string as boolean") - } - switch strings.ToLower(val) { - case "t", "true", "y", "yes", "1": - return true, nil - case "f", "false", "n", "no", "0": - return false, nil - default: - return false, fmt.Errorf("invalid boolean string: %q", val) - } -} - -var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`) - -func parseRoomOrEventID(value string) (*RoomIDValue, error) { - if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") { - matches := markdownLinkRegex.FindStringSubmatch(value) - if len(matches) == 2 { - value = matches[1] - } - } - parsed, err := id.ParseMatrixURIOrMatrixToURL(value) - if err != nil && strings.HasPrefix(value, "!") { - return &RoomIDValue{ - Type: PrimitiveTypeRoomID, - RoomID: id.RoomID(value), - }, nil - } - if err != nil { - return nil, err - } else if parsed.Sigil1 != '!' { - return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1) - } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' { - return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2) - } - valType := PrimitiveTypeRoomID - if parsed.MXID2 != "" { - valType = PrimitiveTypeEventID - } - return &RoomIDValue{ - Type: valType, - RoomID: parsed.RoomID(), - Via: parsed.Via, - EventID: parsed.EventID(), - }, nil -} - -func (pt PrimitiveType) ParseString(value string) (any, error) { - switch pt { - case PrimitiveTypeInteger: - return strconv.Atoi(value) - case PrimitiveTypeBoolean: - return parseBoolean(value) - case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID: - return value, pt.validateStringValue(value) - case PrimitiveTypeRoomAlias: - plainErr := pt.validateStringValue(value) - if plainErr == nil { - return value, nil - } - parsed, err := id.ParseMatrixURIOrMatrixToURL(value) - if err != nil { - return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err) - } else if parsed.Sigil1 != '#' { - return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1) - } - return parsed.RoomAlias(), nil - case PrimitiveTypeRoomID, PrimitiveTypeEventID: - parsed, err := parseRoomOrEventID(value) - if err != nil { - return nil, err - } else if pt != parsed.Type { - return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type) - } - return parsed, nil - default: - return nil, fmt.Errorf("cannot parse string for argument type %s", pt) - } -} - -func (ps *ParameterSchema) ParseString(value string) (any, error) { - if ps == nil { - return nil, fmt.Errorf("parameter schema is nil") - } - switch ps.SchemaType { - case SchemaTypePrimitive: - return ps.Type.ParseString(value) - case SchemaTypeLiteral: - switch typedValue := ps.Value.(type) { - case string: - if value == typedValue { - return typedValue, nil - } else { - return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value) - } - case int, int64, float64, json.Number: - expectedVal, _ := normalizeNumber(typedValue) - intVal, err := strconv.Atoi(value) - if err != nil { - return nil, fmt.Errorf("failed to parse integer literal: %w", err) - } else if intVal != expectedVal { - return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal) - } - return intVal, nil - case bool: - boolVal, err := parseBoolean(value) - if err != nil { - return nil, fmt.Errorf("failed to parse boolean literal: %w", err) - } else if boolVal != typedValue { - return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal) - } - return boolVal, nil - case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage: - expectedVal, _ := NormalizeRoomIDValue(typedValue) - parsed, err := parseRoomOrEventID(value) - if err != nil { - return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err) - } else if !parsed.Equals(expectedVal) { - return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed) - } - return parsed, nil - default: - return nil, fmt.Errorf("unsupported literal type %T", ps.Value) - } - case SchemaTypeUnion: - var errs []error - for _, variant := range ps.Variants { - if parsed, err := variant.ParseString(value); err == nil { - return parsed, nil - } else { - errs = append(errs, err) - } - } - return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...)) - case SchemaTypeArray: - return nil, fmt.Errorf("cannot parse string for array schema type") - default: - return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType) - } -} diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go deleted file mode 100644 index 1e0d1817..00000000 --- a/event/cmdschema/parse_test.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "bytes" - "encoding/json" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "go.mau.fi/util/exbytes" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/event/cmdschema/testdata" -) - -type QuoteParseOutput struct { - Parsed string - Remaining string - Quoted bool -} - -func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error { - var arr []any - if err := json.Unmarshal(data, &arr); err != nil { - return err - } - qpo.Parsed = arr[0].(string) - qpo.Remaining = arr[1].(string) - qpo.Quoted = arr[2].(bool) - return nil -} - -type QuoteParseTestData struct { - Name string `json:"name"` - Input string `json:"input"` - Output QuoteParseOutput `json:"output"` -} - -func loadFile[T any](name string) (into T) { - quoteData := exerrors.Must(testdata.FS.ReadFile(name)) - exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into)) - return -} - -func TestParseQuoted(t *testing.T) { - qptd := loadFile[[]QuoteParseTestData]("parse_quote.json") - for _, test := range qptd { - t.Run(test.Name, func(t *testing.T) { - parsed, remaining, quoted := parseQuoted(test.Input) - assert.Equalf(t, test.Output, QuoteParseOutput{ - Parsed: parsed, - Remaining: remaining, - Quoted: quoted, - }, "Failed with input `%s`", test.Input) - // Note: can't just test that requoted == input, because some inputs - // have unnecessary escapes which won't survive roundtripping - t.Run("roundtrip", func(t *testing.T) { - requoted := quoteString(parsed) + " " + remaining - reparsed, newRemaining, _ := parseQuoted(requoted) - assert.Equal(t, parsed, reparsed) - assert.Equal(t, remaining, newRemaining) - }) - }) - } -} - -type CommandTestData struct { - Spec *EventContent - Tests []*CommandTestUnit -} - -type CommandTestUnit struct { - Name string `json:"name"` - Input string `json:"input"` - Broken string `json:"broken,omitempty"` - Error bool `json:"error"` - Output json.RawMessage `json:"output"` -} - -func compactJSON(input json.RawMessage) json.RawMessage { - var buf bytes.Buffer - exerrors.PanicIfNotNil(json.Compact(&buf, input)) - return buf.Bytes() -} - -func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { - for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) { - t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) { - ctd := loadFile[CommandTestData]("commands/" + cmd.Name()) - for _, test := range ctd.Tests { - outputStr := exbytes.UnsafeString(compactJSON(test.Output)) - t.Run(test.Name, func(t *testing.T) { - if test.Broken != "" { - t.Skip(test.Broken) - } - output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input) - if test.Error { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - if outputStr == "null" { - assert.Nil(t, output) - } else { - assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) - assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input) - } - }) - } - }) - } -} diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go deleted file mode 100644 index 98c421fc..00000000 --- a/event/cmdschema/roomid.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "fmt" - "slices" - "strings" - - "maunium.net/go/mautrix/id" -) - -var ParameterSchemaJoinableRoom = Union( - PrimitiveTypeRoomID.Schema(), - PrimitiveTypeRoomAlias.Schema(), -) - -type RoomIDValue struct { - Type PrimitiveType `json:"type"` - RoomID id.RoomID `json:"id"` - Via []string `json:"via,omitempty"` - EventID id.EventID `json:"event_id,omitempty"` -} - -func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) { - switch typedValue := input.(type) { - case map[string]any, json.RawMessage: - var raw json.RawMessage - if raw, err = json.Marshal(input); err != nil { - err = fmt.Errorf("failed to roundtrip room ID value: %w", err) - } else if err = json.Unmarshal(raw, &riv); err != nil { - err = fmt.Errorf("failed to roundtrip room ID value: %w", err) - } - case *RoomIDValue: - riv = typedValue - case RoomIDValue: - riv = &typedValue - default: - err = fmt.Errorf("unsupported type %T for room or event ID", input) - } - return -} - -func (riv *RoomIDValue) String() string { - return riv.URI().String() -} - -func (riv *RoomIDValue) URI() *id.MatrixURI { - if riv == nil { - return nil - } - switch riv.Type { - case PrimitiveTypeRoomID: - return riv.RoomID.URI(riv.Via...) - case PrimitiveTypeEventID: - return riv.RoomID.EventURI(riv.EventID, riv.Via...) - default: - return nil - } -} - -func (riv *RoomIDValue) Equals(other *RoomIDValue) bool { - if riv == nil || other == nil { - return riv == other - } - return riv.Type == other.Type && - riv.RoomID == other.RoomID && - riv.EventID == other.EventID && - slices.Equal(riv.Via, other.Via) -} - -func (riv *RoomIDValue) Validate() error { - if riv == nil { - return fmt.Errorf("value is nil") - } - switch riv.Type { - case PrimitiveTypeRoomID: - if riv.EventID != "" { - return fmt.Errorf("event ID must be empty for room ID type") - } - case PrimitiveTypeEventID: - if !strings.HasPrefix(riv.EventID.String(), "$") { - return fmt.Errorf("event ID not valid: %q", riv.EventID) - } - default: - return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type) - } - for _, via := range riv.Via { - if !id.ValidateServerName(via) { - return fmt.Errorf("invalid server name %q in vias", via) - } - } - sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID) - if sigil != '!' { - return fmt.Errorf("room ID does not start with !: %q", riv.RoomID) - } else if localpart == "" && serverName == "" { - return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID) - } else if serverName != "" && !id.ValidateServerName(serverName) { - return fmt.Errorf("invalid server name %q in room ID", serverName) - } - return nil -} - -func (riv *RoomIDValue) IsValid() bool { - return riv.Validate() == nil -} - -type RoomIDOrString string - -func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error { - if len(data) == 0 { - return fmt.Errorf("empty data for room ID or string") - } - if data[0] == '"' { - var str string - if err := json.Unmarshal(data, &str); err != nil { - return err - } - *ros = RoomIDOrString(str) - return nil - } - var riv RoomIDValue - if err := json.Unmarshal(data, &riv); err != nil { - return err - } else if err = riv.Validate(); err != nil { - return err - } - *ros = RoomIDOrString(riv.String()) - return nil -} diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go deleted file mode 100644 index c5c57c53..00000000 --- a/event/cmdschema/stringify.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "strconv" - "strings" -) - -var quoteEscaper = strings.NewReplacer( - `"`, `\"`, - `\`, `\\`, -) - -const charsToQuote = ` \` + botArrayOpener + botArrayCloser - -func quoteString(val string) string { - if val == "" { - return `""` - } - val = quoteEscaper.Replace(val) - if strings.ContainsAny(val, charsToQuote) { - return `"` + val + `"` - } - return val -} - -func (ec *EventContent) StringifyArgs(args any) string { - var argMap map[string]any - switch typedArgs := args.(type) { - case json.RawMessage: - err := json.Unmarshal(typedArgs, &argMap) - if err != nil { - return "" - } - case map[string]any: - argMap = typedArgs - default: - if b, err := json.Marshal(args); err != nil { - return "" - } else if err = json.Unmarshal(b, &argMap); err != nil { - return "" - } - } - parts := make([]string, 0, len(ec.Parameters)) - for i, param := range ec.Parameters { - isLast := i == len(ec.Parameters)-1 - val := argMap[param.Key] - if val == nil { - val = param.DefaultValue - if val == nil && !param.Optional { - val = param.Schema.GetDefaultValue() - } - } - if val == nil { - continue - } - var stringified string - if param.Schema.SchemaType == SchemaTypeArray { - stringified = arrayArgumentToString(val, isLast) - } else { - stringified = singleArgumentToString(val) - } - if stringified != "" { - parts = append(parts, stringified) - } - } - return strings.Join(parts, " ") -} - -func arrayArgumentToString(val any, isLast bool) string { - valArr, ok := val.([]any) - if !ok { - return "" - } - parts := make([]string, 0, len(valArr)) - for _, elem := range valArr { - stringified := singleArgumentToString(elem) - if stringified != "" { - parts = append(parts, stringified) - } - } - joinedParts := strings.Join(parts, " ") - if isLast && len(parts) > 0 { - return joinedParts - } - return botArrayOpener + joinedParts + botArrayCloser -} - -func singleArgumentToString(val any) string { - switch typedVal := val.(type) { - case string: - return quoteString(typedVal) - case json.Number: - return typedVal.String() - case bool: - return strconv.FormatBool(typedVal) - case int: - return strconv.Itoa(typedVal) - case int64: - return strconv.FormatInt(typedVal, 10) - case float64: - return strconv.FormatInt(int64(typedVal), 10) - case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue: - normalized, err := NormalizeRoomIDValue(typedVal) - if err != nil { - return "" - } - uri := normalized.URI() - if uri == nil { - return "" - } - return quoteString(uri.String()) - default: - return "" - } -} diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json deleted file mode 100644 index e53382db..00000000 --- a/event/cmdschema/testdata/commands.schema.json +++ /dev/null @@ -1,281 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema#", - "$id": "commands.schema.json", - "title": "ParseInput test cases", - "description": "JSON schema for test case files containing command specifications and test cases", - "type": "object", - "required": [ - "spec", - "tests" - ], - "additionalProperties": false, - "properties": { - "spec": { - "title": "MSC4391 Command Description", - "description": "JSON schema defining the structure of a bot command event content", - "type": "object", - "required": [ - "command" - ], - "additionalProperties": false, - "properties": { - "command": { - "type": "string", - "description": "The command name that triggers this bot command" - }, - "aliases": { - "type": "array", - "description": "Alternative names/aliases for this command", - "items": { - "type": "string" - } - }, - "parameters": { - "type": "array", - "description": "List of parameters accepted by this command", - "items": { - "$ref": "#/$defs/Parameter" - } - }, - "description": { - "$ref": "#/$defs/ExtensibleTextContainer", - "description": "Human-readable description of the command" - }, - "fi.mau.tail_parameter": { - "type": "string", - "description": "The key of the parameter that accepts remaining arguments as tail text" - }, - "source": { - "type": "string", - "description": "The user ID of the bot that responds to this command" - } - } - }, - "tests": { - "type": "array", - "description": "Array of test cases for the command", - "items": { - "type": "object", - "description": "A single test case for command parsing", - "required": [ - "name", - "input" - ], - "additionalProperties": false, - "properties": { - "name": { - "type": "string", - "description": "The name of the test case" - }, - "input": { - "type": "string", - "description": "The command input string to parse" - }, - "output": { - "description": "The expected parsed parameter values, or null if the parsing is expected to fail", - "oneOf": [ - { - "type": "object", - "additionalProperties": true - }, - { - "type": "null" - } - ] - }, - "error": { - "type": "boolean", - "description": "Whether parsing should result in an error. May still produce output.", - "default": false - } - } - } - } - }, - "$defs": { - "ExtensibleTextContainer": { - "type": "object", - "description": "Container for text that can have multiple representations", - "required": [ - "m.text" - ], - "properties": { - "m.text": { - "type": "array", - "description": "Array of text representations in different formats", - "items": { - "$ref": "#/$defs/ExtensibleText" - } - } - } - }, - "ExtensibleText": { - "type": "object", - "description": "A text representation with a specific MIME type", - "required": [ - "body" - ], - "properties": { - "body": { - "type": "string", - "description": "The text content" - }, - "mimetype": { - "type": "string", - "description": "The MIME type of the text (e.g., text/plain, text/html)", - "default": "text/plain", - "examples": [ - "text/plain", - "text/html" - ] - } - } - }, - "Parameter": { - "type": "object", - "description": "A parameter definition for a command", - "required": [ - "key", - "schema" - ], - "additionalProperties": false, - "properties": { - "key": { - "type": "string", - "description": "The identifier for this parameter" - }, - "schema": { - "$ref": "#/$defs/ParameterSchema", - "description": "The schema defining the type and structure of this parameter" - }, - "optional": { - "type": "boolean", - "description": "Whether this parameter is optional", - "default": false - }, - "description": { - "$ref": "#/$defs/ExtensibleTextContainer", - "description": "Human-readable description of this parameter" - }, - "fi.mau.default_value": { - "description": "Default value for this parameter if not provided" - } - } - }, - "ParameterSchema": { - "type": "object", - "description": "Schema definition for a parameter value", - "required": [ - "schema_type" - ], - "additionalProperties": false, - "properties": { - "schema_type": { - "type": "string", - "enum": [ - "primitive", - "array", - "union", - "literal" - ], - "description": "The type of schema" - } - }, - "allOf": [ - { - "if": { - "properties": { - "schema_type": { - "const": "primitive" - } - } - }, - "then": { - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "string", - "integer", - "boolean", - "server_name", - "user_id", - "room_id", - "room_alias", - "event_id" - ], - "description": "The primitive type (only for schema_type: primitive)" - } - } - } - }, - { - "if": { - "properties": { - "schema_type": { - "const": "array" - } - } - }, - "then": { - "required": [ - "items" - ], - "properties": { - "items": { - "$ref": "#/$defs/ParameterSchema", - "description": "The schema for array items (only for schema_type: array)" - } - } - } - }, - { - "if": { - "properties": { - "schema_type": { - "const": "union" - } - } - }, - "then": { - "required": [ - "variants" - ], - "properties": { - "variants": { - "type": "array", - "description": "The possible variants (only for schema_type: union)", - "items": { - "$ref": "#/$defs/ParameterSchema" - }, - "minItems": 1 - } - } - } - }, - { - "if": { - "properties": { - "schema_type": { - "const": "literal" - } - } - }, - "then": { - "required": [ - "value" - ], - "properties": { - "value": { - "description": "The literal value (only for schema_type: literal)" - } - } - } - } - ] - } - } -} diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json deleted file mode 100644 index 6ce1f4da..00000000 --- a/event/cmdschema/testdata/commands/flags.json +++ /dev/null @@ -1,126 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "flag", - "source": "@testbot", - "parameters": [ - { - "key": "meow", - "schema": { - "schema_type": "primitive", - "type": "string" - } - }, - { - "key": "user", - "schema": { - "schema_type": "primitive", - "type": "user_id" - }, - "optional": true - }, - { - "key": "woof", - "schema": { - "schema_type": "primitive", - "type": "boolean" - }, - "optional": true, - "fi.mau.default_value": false - } - ], - "fi.mau.tail_parameter": "user" - }, - "tests": [ - { - "name": "no flags", - "input": "/flag mrrp", - "output": { - "meow": "mrrp", - "user": null - } - }, - { - "name": "no flags, has tail", - "input": "/flag mrrp @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com" - } - }, - { - "name": "named flag at start", - "input": "/flag --woof=yes mrrp @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "boolean flag without value", - "input": "/flag --woof mrrp @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "user id flag without value", - "input": "/flag --user --woof mrrp", - "error": true, - "output": { - "meow": "mrrp", - "user": null, - "woof": true - } - }, - { - "name": "named flag in the middle", - "input": "/flag mrrp --woof=yes @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "named flag in the middle with different value", - "input": "/flag mrrp --woof=no @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": false - } - }, - { - "name": "all variables named", - "input": "/flag --woof=no --meow=mrrp --user=@user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": false - } - }, - { - "name": "all variables named with quotes", - "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"", - "output": { - "meow": "meow meow mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "invalid value for named parameter", - "input": "/flag --user=meowings mrrp --woof", - "error": true, - "output": { - "meow": "mrrp", - "user": null, - "woof": true - } - } - ] -} diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json deleted file mode 100644 index 1351c292..00000000 --- a/event/cmdschema/testdata/commands/room_id_or_alias.json +++ /dev/null @@ -1,85 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "test room reference", - "source": "@testbot", - "parameters": [ - { - "key": "room", - "schema": { - "schema_type": "union", - "variants": [ - { - "schema_type": "primitive", - "type": "room_id" - }, - { - "schema_type": "primitive", - "type": "room_alias" - } - ] - } - } - ] - }, - "tests": [ - { - "name": "room alias", - "input": "/test room reference #test:matrix.org", - "output": { - "room": "#test:matrix.org" - } - }, - { - "name": "room id", - "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", - "output": { - "room": { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" - } - } - }, - { - "name": "room id matrix.to link", - "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com", - "output": { - "room": { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", - "via": [ - "example.com" - ] - } - } - }, - { - "name": "room id matrix.to link with url encoding", - "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net", - "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly", - "output": { - "room": { - "type": "room_id", - "id": "!#test/room\nversion 11, with @🐈️:maunium.net", - "via": [ - "maunium.net" - ] - } - } - }, - { - "name": "room id matrix: URI", - "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", - "output": { - "room": { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - "via": [ - "maunium.net", - "matrix.org" - ] - } - } - } - ] -} diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json deleted file mode 100644 index aa266054..00000000 --- a/event/cmdschema/testdata/commands/room_reference_list.json +++ /dev/null @@ -1,106 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "test room reference", - "source": "@testbot", - "parameters": [ - { - "key": "rooms", - "schema": { - "schema_type": "array", - "items": { - "schema_type": "union", - "variants": [ - { - "schema_type": "primitive", - "type": "room_id" - }, - { - "schema_type": "primitive", - "type": "room_alias" - } - ] - } - } - } - ] - }, - "tests": [ - { - "name": "room alias", - "input": "/test room reference #test:matrix.org", - "output": { - "rooms": [ - "#test:matrix.org" - ] - } - }, - { - "name": "room id", - "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" - } - ] - } - }, - { - "name": "two room ids", - "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ" - }, - { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" - } - ] - } - }, - { - "name": "room id matrix: URI", - "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - "via": [ - "maunium.net", - "matrix.org" - ] - } - ] - } - }, - { - "name": "room id matrix: URI and matrix.to URL", - "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", - "via": [ - "example.com" - ] - }, - { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - "via": [ - "maunium.net", - "matrix.org" - ] - } - ] - } - } - ] -} diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json deleted file mode 100644 index 94667323..00000000 --- a/event/cmdschema/testdata/commands/simple.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "test simple", - "source": "@testbot", - "parameters": [ - { - "key": "meow", - "schema": { - "schema_type": "primitive", - "type": "string" - } - } - ] - }, - "tests": [ - { - "name": "success", - "input": "/test simple mrrp", - "output": { - "meow": "mrrp" - } - }, - { - "name": "directed success", - "input": "/test simple@testbot mrrp", - "output": { - "meow": "mrrp" - } - }, - { - "name": "missing parameter", - "input": "/test simple", - "error": true, - "output": { - "meow": "" - } - }, - { - "name": "directed at another bot", - "input": "/test simple@anotherbot mrrp", - "error": false, - "output": null - } - ] -} diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json deleted file mode 100644 index 9782f8ec..00000000 --- a/event/cmdschema/testdata/commands/tail.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "tail", - "source": "@testbot", - "parameters": [ - { - "key": "meow", - "schema": { - "schema_type": "primitive", - "type": "string" - } - }, - { - "key": "reason", - "schema": { - "schema_type": "primitive", - "type": "string" - }, - "optional": true - }, - { - "key": "woof", - "schema": { - "schema_type": "primitive", - "type": "boolean" - }, - "optional": true - } - ], - "fi.mau.tail_parameter": "reason" - }, - "tests": [ - { - "name": "no tail or flag", - "input": "/tail mrrp", - "output": { - "meow": "mrrp", - "reason": "" - } - }, - { - "name": "tail, no flag", - "input": "/tail mrrp meow meow", - "output": { - "meow": "mrrp", - "reason": "meow meow" - } - }, - { - "name": "flag before tail", - "input": "/tail mrrp --woof meow meow", - "output": { - "meow": "mrrp", - "reason": "meow meow", - "woof": true - } - } - ] -} diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json deleted file mode 100644 index 8f52b7f5..00000000 --- a/event/cmdschema/testdata/parse_quote.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - {"name": "empty string", "input": "", "output": ["", "", false]}, - {"name": "single word", "input": "meow", "output": ["meow", "", false]}, - {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]}, - {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, - {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, - {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]}, - {"name": "only spaces", "input": " ", "output": ["", "", false]}, - {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]}, - {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]}, - {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, - {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]}, - {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]}, - {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]}, - {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]}, - {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]}, - {"name": "just opening quote", "input": "\"", "output": ["", "", true]}, - {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]}, - {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]}, - {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]}, - {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]}, - {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]}, - {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]}, - {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]}, - {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, - {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]}, - {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]}, - {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]}, - {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]} -] diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json deleted file mode 100644 index 9f249116..00000000 --- a/event/cmdschema/testdata/parse_quote.schema.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema#", - "$id": "parse_quote.schema.json", - "title": "parseQuote test cases", - "description": "Test cases for the parseQuoted function", - "type": "array", - "items": { - "type": "object", - "required": [ - "name", - "input", - "output" - ], - "properties": { - "name": { - "type": "string", - "description": "Name of the test case" - }, - "input": { - "type": "string", - "description": "Input string to be parsed" - }, - "output": { - "type": "array", - "description": "Expected output of parsing: [first word, remaining text, was quoted]", - "minItems": 3, - "maxItems": 3, - "prefixItems": [ - { - "type": "string", - "description": "First parsed word" - }, - { - "type": "string", - "description": "Remaining text after the first word" - }, - { - "type": "boolean", - "description": "Whether the first word was quoted" - } - ] - } - }, - "additionalProperties": false - } -} diff --git a/event/content.go b/event/content.go index 814aeec4..e0026e9e 100644 --- a/event/content.go +++ b/event/content.go @@ -18,7 +18,6 @@ import ( // This is used by Content.ParseRaw() for creating the correct type of struct. var TypeMap = map[Type]reflect.Type{ StateMember: reflect.TypeOf(MemberEventContent{}), - StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}), StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}), StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}), StateRoomName: reflect.TypeOf(RoomNameEventContent{}), @@ -39,20 +38,9 @@ var TypeMap = map[Type]reflect.Type{ StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}), StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}), StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), - - StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), - StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), - - StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), - StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), - StateLegacyPolicyUser: reflect.TypeOf(ModPolicyContent{}), - StateUnstablePolicyRoom: reflect.TypeOf(ModPolicyContent{}), - StateUnstablePolicyServer: reflect.TypeOf(ModPolicyContent{}), - StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}), + StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}), StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), - StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), - StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), @@ -60,14 +48,7 @@ var TypeMap = map[Type]reflect.Type{ EventRedaction: reflect.TypeOf(RedactionEventContent{}), EventReaction: reflect.TypeOf(ReactionEventContent{}), - EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}), - EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), - - BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), - BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), - BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), - BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}), - BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), + BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), @@ -76,11 +57,9 @@ var TypeMap = map[Type]reflect.Type{ AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}), - EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), - EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), - EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), - EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}), - BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}), + EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), + EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), + EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), @@ -132,7 +111,7 @@ var TypeMap = map[Type]reflect.Type{ // When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged // with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging, // rather than overriding the whole object with the one in Raw). -// If one of them is nil, then only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. +// If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. type Content struct { VeryRaw json.RawMessage Raw map[string]interface{} @@ -199,13 +178,6 @@ func IsUnsupportedContentType(err error) bool { var ErrContentAlreadyParsed = errors.New("content is already parsed") var ErrUnsupportedContentType = errors.New("unsupported event type") -func (content *Content) GetRaw() map[string]interface{} { - if content.Raw == nil { - content.Raw = make(map[string]interface{}) - } - return content.Raw -} - func (content *Content) ParseRaw(evtType Type) error { if content.Parsed != nil { return ErrContentAlreadyParsed diff --git a/event/delayed.go b/event/delayed.go deleted file mode 100644 index fefb62af..00000000 --- a/event/delayed.go +++ /dev/null @@ -1,70 +0,0 @@ -package event - -import ( - "encoding/json" - - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/id" -) - -type ScheduledDelayedEvent struct { - DelayID id.DelayID `json:"delay_id"` - RoomID id.RoomID `json:"room_id"` - Type Type `json:"type"` - StateKey *string `json:"state_key,omitempty"` - Delay int64 `json:"delay"` - RunningSince jsontime.UnixMilli `json:"running_since"` - Content Content `json:"content"` -} - -func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) { - evt := &Event{ - ID: eventID, - RoomID: e.RoomID, - Type: e.Type, - StateKey: e.StateKey, - Content: e.Content, - Timestamp: ts.UnixMilli(), - } - return evt, evt.Content.ParseRaw(evt.Type) -} - -type FinalisedDelayedEvent struct { - DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"` - Outcome DelayOutcome `json:"outcome"` - Reason DelayReason `json:"reason"` - Error json.RawMessage `json:"error,omitempty"` - EventID id.EventID `json:"event_id,omitempty"` - Timestamp jsontime.UnixMilli `json:"origin_server_ts"` -} - -type DelayStatus string - -var ( - DelayStatusScheduled DelayStatus = "scheduled" - DelayStatusFinalised DelayStatus = "finalised" -) - -type DelayAction string - -var ( - DelayActionSend DelayAction = "send" - DelayActionCancel DelayAction = "cancel" - DelayActionRestart DelayAction = "restart" -) - -type DelayOutcome string - -var ( - DelayOutcomeSend DelayOutcome = "send" - DelayOutcomeCancel DelayOutcome = "cancel" -) - -type DelayReason string - -var ( - DelayReasonAction DelayReason = "action" - DelayReasonError DelayReason = "error" - DelayReasonDelay DelayReason = "delay" -) diff --git a/event/encryption.go b/event/encryption.go index c60cb91a..cf9c2814 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString) + return id.InputNotJSONString } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } @@ -132,9 +132,8 @@ type RoomKeyRequestEventContent struct { type RequestedKeyInfo struct { Algorithm id.Algorithm `json:"algorithm"` RoomID id.RoomID `json:"room_id"` - SessionID id.SessionID `json:"session_id"` - // Deprecated: Matrix v1.3 SenderKey id.SenderKey `json:"sender_key"` + SessionID id.SessionID `json:"session_id"` } type RoomKeyWithheldCode string diff --git a/event/events.go b/event/events.go index 72c1e161..4653a531 100644 --- a/event/events.go +++ b/event/events.go @@ -118,9 +118,6 @@ type MautrixInfo struct { DecryptionDuration time.Duration CheckpointSent bool - // When using MSC4222 and the state_after field, this field is set - // for timeline events to indicate they shouldn't update room state. - IgnoreState bool } func (evt *Event) GetStateKey() string { @@ -130,29 +127,28 @@ func (evt *Event) GetStateKey() string { return "" } +type StrippedState struct { + Content Content `json:"content"` + Type Type `json:"type"` + StateKey string `json:"state_key"` + Sender id.UserID `json:"sender"` +} + type Unsigned struct { - PrevContent *Content `json:"prev_content,omitempty"` - PrevSender id.UserID `json:"prev_sender,omitempty"` - Membership Membership `json:"membership,omitempty"` - ReplacesState id.EventID `json:"replaces_state,omitempty"` - Age int64 `json:"age,omitempty"` - TransactionID string `json:"transaction_id,omitempty"` - Relations *Relations `json:"m.relations,omitempty"` - RedactedBecause *Event `json:"redacted_because,omitempty"` - InviteRoomState []*Event `json:"invite_room_state,omitempty"` + PrevContent *Content `json:"prev_content,omitempty"` + PrevSender id.UserID `json:"prev_sender,omitempty"` + ReplacesState id.EventID `json:"replaces_state,omitempty"` + Age int64 `json:"age,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + Relations *Relations `json:"m.relations,omitempty"` + RedactedBecause *Event `json:"redacted_because,omitempty"` + InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` - BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` - BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"` - BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` - BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` - - ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"` - ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"` + BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` } func (us *Unsigned) IsEmpty() bool { - return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" && + return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && - us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() && - !us.ElementSoftFailed + us.BeeperHSOrder == 0 } diff --git a/event/member.go b/event/member.go index 9956a36b..ebafdcb7 100644 --- a/event/member.go +++ b/event/member.go @@ -7,6 +7,8 @@ package event import ( + "encoding/json" + "maunium.net/go/mautrix/id" ) @@ -33,37 +35,19 @@ const ( // MemberEventContent represents the content of a m.room.member state event. // https://spec.matrix.org/v1.2/client-server-api/#mroommember type MemberEventContent struct { - Membership Membership `json:"membership"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Displayname string `json:"displayname,omitempty"` - IsDirect bool `json:"is_direct,omitempty"` - ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` - Reason string `json:"reason,omitempty"` - JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"` - MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` - - MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` -} - -type SignedThirdPartyInvite struct { - Token string `json:"token"` - Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` - MXID string `json:"mxid"` + Membership Membership `json:"membership"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Displayname string `json:"displayname,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` + ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` + Reason string `json:"reason,omitempty"` } type ThirdPartyInvite struct { - DisplayName string `json:"display_name"` - Signed SignedThirdPartyInvite `json:"signed"` -} - -type ThirdPartyInviteEventContent struct { - DisplayName string `json:"display_name"` - KeyValidityURL string `json:"key_validity_url"` - PublicKey id.Ed25519 `json:"public_key"` - PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"` -} - -type ThirdPartyInviteKey struct { - KeyValidityURL string `json:"key_validity_url,omitempty"` - PublicKey id.Ed25519 `json:"public_key"` + DisplayName string `json:"display_name"` + Signed struct { + Token string `json:"token"` + Signatures json.RawMessage `json:"signatures"` + MXID string `json:"mxid"` + } } diff --git a/event/message.go b/event/message.go index 3fb3dc82..3c6edfdd 100644 --- a/event/message.go +++ b/event/message.go @@ -8,11 +8,12 @@ package event import ( "encoding/json" - "html" "slices" "strconv" "strings" + "golang.org/x/net/html" + "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/id" ) @@ -32,7 +33,7 @@ func (mt MessageType) IsText() bool { func (mt MessageType) IsMedia() bool { switch mt { - case MsgImage, MsgVideo, MsgAudio, MsgFile, CapMsgSticker: + case MsgImage, MsgVideo, MsgAudio, MsgFile, MessageType(EventSticker.Type): return true default: return false @@ -135,63 +136,11 @@ type MessageEventContent struct { BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"` - BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` - BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"` - MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` - - MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"` -} - -func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { - switch content.MsgType { - case CapMsgSticker: - return CapMsgSticker - case "": - if content.URL != "" || content.File != nil { - return CapMsgSticker - } - case MsgImage: - return MsgImage - case MsgAudio: - if content.MSC3245Voice != nil { - return CapMsgVoice - } - return MsgAudio - case MsgVideo: - if content.Info != nil && content.Info.MauGIF { - return CapMsgGIF - } - return MsgVideo - case MsgFile: - return MsgFile - } - return "" -} - -func (content *MessageEventContent) GetFileName() string { - if content.FileName != "" { - return content.FileName - } - return content.Body -} - -func (content *MessageEventContent) GetCaption() string { - if content.FileName != "" && content.Body != "" && content.Body != content.FileName { - return content.Body - } - return "" -} - -func (content *MessageEventContent) GetFormattedCaption() string { - if content.Format == FormatHTML && content.FormattedBody != "" { - return content.FormattedBody - } - return "" } func (content *MessageEventContent) GetRelatesTo() *RelatesTo { @@ -215,7 +164,6 @@ func (content *MessageEventContent) SetEdit(original id.EventID) { content.RelatesTo = (&RelatesTo{}).SetReplace(original) if content.MsgType == MsgText || content.MsgType == MsgNotice { content.Body = "* " + content.Body - content.Mentions = &Mentions{} if content.Format == FormatHTML && len(content.FormattedBody) > 0 { content.FormattedBody = "* " + content.FormattedBody } @@ -272,50 +220,24 @@ func (m *Mentions) Add(userID id.UserID) { } } -func (m *Mentions) Has(userID id.UserID) bool { - return m != nil && slices.Contains(m.UserIDs, userID) -} - -func (m *Mentions) Merge(other *Mentions) *Mentions { - if m == nil { - return other - } else if other == nil { - return m - } - return &Mentions{ - UserIDs: slices.Concat(m.UserIDs, other.UserIDs), - Room: m.Room || other.Room, - } -} - -type MSC4391BotCommandInputCustom[T any] struct { - Command string `json:"command"` - Arguments T `json:"arguments,omitempty"` -} - -type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage] - type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` } type FileInfo struct { - MimeType string - ThumbnailInfo *FileInfo - ThumbnailURL id.ContentURIString - ThumbnailFile *EncryptedFileInfo + MimeType string `json:"mimetype,omitempty"` + ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"` + ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` + ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` - Blurhash string - AnoaBlurhash string + Blurhash string `json:"blurhash,omitempty"` + AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` - MauGIF bool - IsAnimated bool - - Width int - Height int - Duration int - Size int + Width int `json:"-"` + Height int `json:"-"` + Duration int `json:"-"` + Size int `json:"-"` } type serializableFileInfo struct { @@ -327,9 +249,6 @@ type serializableFileInfo struct { Blurhash string `json:"blurhash,omitempty"` AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` - MauGIF bool `json:"fi.mau.gif,omitempty"` - IsAnimated bool `json:"is_animated,omitempty"` - Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` Duration json.Number `json:"duration,omitempty"` @@ -346,9 +265,6 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, - MauGIF: fileInfo.MauGIF, - IsAnimated: fileInfo.IsAnimated, - Blurhash: fileInfo.Blurhash, AnoaBlurhash: fileInfo.AnoaBlurhash, } @@ -377,8 +293,6 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { MimeType: sfi.MimeType, ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, - MauGIF: sfi.MauGIF, - IsAnimated: sfi.IsAnimated, Blurhash: sfi.Blurhash, AnoaBlurhash: sfi.AnoaBlurhash, } diff --git a/event/message_test.go b/event/message_test.go index c721df35..562a6622 100644 --- a/event/message_test.go +++ b/event/message_test.go @@ -33,7 +33,7 @@ const invalidMessageEvent = `{ func TestMessageEventContent__ParseInvalid(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(invalidMessageEvent), &evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) { assert.Equal(t, id.RoomID("!bar"), evt.RoomID) err = evt.Content.ParseRaw(evt.Type) - assert.Error(t, err) + assert.NotNil(t, err) } const messageEvent = `{ @@ -68,7 +68,7 @@ const messageEvent = `{ func TestMessageEventContent__ParseEdit(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(messageEvent), &evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -110,7 +110,7 @@ const imageMessageEvent = `{ func TestMessageEventContent__ParseMedia(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(imageMessageEvent), &evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) { content := evt.Content.Parsed.(*event.MessageEventContent) assert.Equal(t, event.MsgImage, content.MsgType) parsedURL, err := content.URL.Parse() - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL) assert.Nil(t, content.NewContent) assert.Equal(t, "image/png", content.GetInfo().MimeType) @@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}` func TestMessageEventContent__Marshal(t *testing.T) { data, err := json.Marshal(parsedMessage) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, expectedMarshalResult, string(data)) } @@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun func TestMessageEventContent__Marshal_Custom(t *testing.T) { data, err := json.Marshal(customParsedMessage) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, expectedCustomMarshalResult, string(data)) } diff --git a/event/poll.go b/event/poll.go deleted file mode 100644 index 9082f65e..00000000 --- a/event/poll.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package event - -type PollResponseEventContent struct { - RelatesTo RelatesTo `json:"m.relates_to"` - Response struct { - Answers []string `json:"answers"` - } `json:"org.matrix.msc3381.poll.response"` -} - -func (content *PollResponseEventContent) GetRelatesTo() *RelatesTo { - return &content.RelatesTo -} - -func (content *PollResponseEventContent) OptionalGetRelatesTo() *RelatesTo { - if content.RelatesTo.Type == "" { - return nil - } - return &content.RelatesTo -} - -func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) { - content.RelatesTo = *rel -} - -type MSC1767Message struct { - Text string `json:"org.matrix.msc1767.text,omitempty"` - HTML string `json:"org.matrix.msc1767.html,omitempty"` - Message []ExtensibleText `json:"org.matrix.msc1767.message,omitempty"` -} - -type PollStartEventContent struct { - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` - Mentions *Mentions `json:"m.mentions,omitempty"` - PollStart struct { - Kind string `json:"kind"` - MaxSelections int `json:"max_selections"` - Question MSC1767Message `json:"question"` - Answers []struct { - ID string `json:"id"` - MSC1767Message - } `json:"answers"` - } `json:"org.matrix.msc3381.poll.start"` -} - -func (content *PollStartEventContent) GetRelatesTo() *RelatesTo { - if content.RelatesTo == nil { - content.RelatesTo = &RelatesTo{} - } - return content.RelatesTo -} - -func (content *PollStartEventContent) OptionalGetRelatesTo() *RelatesTo { - return content.RelatesTo -} - -func (content *PollStartEventContent) SetRelatesTo(rel *RelatesTo) { - content.RelatesTo = rel -} diff --git a/event/powerlevels.go b/event/powerlevels.go index 668eb6d3..2f4d4573 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -7,8 +7,6 @@ package event import ( - "math" - "slices" "sync" "go.mau.fi/util/ptr" @@ -28,9 +26,6 @@ type PowerLevelsEventContent struct { Events map[string]int `json:"events,omitempty"` EventsDefault int `json:"events_default,omitempty"` - beeperEphemeralLock sync.RWMutex - BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"` - Notifications *NotificationPowerLevels `json:"notifications,omitempty"` StateDefaultPtr *int `json:"state_default,omitempty"` @@ -39,12 +34,6 @@ type PowerLevelsEventContent struct { KickPtr *int `json:"kick,omitempty"` BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` - - BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"` - - // This is not a part of power levels, it's added by mautrix-go internally in certain places - // in order to detect creator power accurately. - CreateEvent *Event `json:"-"` } func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { @@ -56,7 +45,6 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { UsersDefault: pl.UsersDefault, Events: maps.Clone(pl.Events), EventsDefault: pl.EventsDefault, - BeeperEphemeral: maps.Clone(pl.BeeperEphemeral), StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr), Notifications: pl.Notifications.Clone(), @@ -65,10 +53,6 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { KickPtr: ptr.Clone(pl.KickPtr), BanPtr: ptr.Clone(pl.BanPtr), RedactPtr: ptr.Clone(pl.RedactPtr), - - BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr), - - CreateEvent: pl.CreateEvent, } } @@ -127,17 +111,7 @@ func (pl *PowerLevelsEventContent) StateDefault() int { return 50 } -func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int { - if pl.BeeperEphemeralDefaultPtr != nil { - return *pl.BeeperEphemeralDefaultPtr - } - return pl.EventsDefault -} - func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { - if pl.isCreator(userID) { - return math.MaxInt - } pl.usersLock.RLock() defer pl.usersLock.RUnlock() level, ok := pl.Users[userID] @@ -147,19 +121,9 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { return level } -const maxPL = 1<<53 - 1 - func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() - if pl.isCreator(userID) { - return - } - if level == math.MaxInt && maxPL < math.MaxInt { - // Hack to avoid breaking on 32-bit systems (they're only slightly supported) - x := int64(maxPL) - level = int(x) - } if level == pl.UsersDefault { delete(pl.Users, userID) } else { @@ -174,24 +138,9 @@ func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int) return pl.EnsureUserLevelAs("", target, level) } -func (pl *PowerLevelsEventContent) createContent() *CreateEventContent { - if pl.CreateEvent == nil { - return &CreateEventContent{} - } - return pl.CreateEvent.Content.AsCreate() -} - -func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool { - cc := pl.createContent() - return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID)) -} - func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool { - if pl.isCreator(target) { - return false - } existingLevel := pl.GetUserLevel(target) - if actor != "" && !pl.isCreator(actor) { + if actor != "" { actorLevel := pl.GetUserLevel(actor) if actorLevel <= existingLevel || actorLevel < level { return false @@ -217,29 +166,6 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int { return level } -func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int { - pl.beeperEphemeralLock.RLock() - defer pl.beeperEphemeralLock.RUnlock() - level, ok := pl.BeeperEphemeral[eventType.String()] - if !ok { - return pl.BeeperEphemeralDefault() - } - return level -} - -func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) { - pl.beeperEphemeralLock.Lock() - defer pl.beeperEphemeralLock.Unlock() - if level == pl.BeeperEphemeralDefault() { - delete(pl.BeeperEphemeral, eventType.String()) - } else { - if pl.BeeperEphemeral == nil { - pl.BeeperEphemeral = make(map[string]int) - } - pl.BeeperEphemeral[eventType.String()] = level - } -} - func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() @@ -259,7 +185,7 @@ func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) b func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool { existingLevel := pl.GetEventLevel(eventType) - if actor != "" && !pl.isCreator(actor) { + if actor != "" { actorLevel := pl.GetUserLevel(actor) if existingLevel > actorLevel || level > actorLevel { return false diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go deleted file mode 100644 index f5861583..00000000 --- a/event/powerlevels_ephemeral_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package event_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/event" -) - -func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) { - pl := &event.PowerLevelsEventContent{ - EventsDefault: 45, - } - - assert.Equal(t, 45, pl.BeeperEphemeralDefault()) - - override := 60 - pl.BeeperEphemeralDefaultPtr = &override - assert.Equal(t, 60, pl.BeeperEphemeralDefault()) -} - -func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) { - pl := &event.PowerLevelsEventContent{ - EventsDefault: 25, - } - evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} - - assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType)) - - pl.SetBeeperEphemeralLevel(evtType, 50) - assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType)) - require.NotNil(t, pl.BeeperEphemeral) - assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()]) - - pl.SetBeeperEphemeralLevel(evtType, 25) - _, exists := pl.BeeperEphemeral[evtType.String()] - assert.False(t, exists) -} - -func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) { - override := 70 - pl := &event.PowerLevelsEventContent{ - EventsDefault: 35, - BeeperEphemeral: map[string]int{"com.example.ephemeral": 90}, - BeeperEphemeralDefaultPtr: &override, - } - - cloned := pl.Clone() - require.NotNil(t, cloned) - require.NotNil(t, cloned.BeeperEphemeralDefaultPtr) - assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr) - assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"]) - - cloned.BeeperEphemeral["com.example.ephemeral"] = 99 - *cloned.BeeperEphemeralDefaultPtr = 71 - - assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"]) - assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr) -} diff --git a/event/profile.go b/event/profile.go new file mode 100644 index 00000000..6dc4314a --- /dev/null +++ b/event/profile.go @@ -0,0 +1,10 @@ +package event + +import "maunium.net/go/mautrix/id" + +type BeeperPerMessageProfile struct { + ID string `json:"id"` + Displayname string `json:"displayname,omitempty"` + AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` + AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` +} diff --git a/event/relations.go b/event/relations.go index 2316cbc7..ea40cc06 100644 --- a/event/relations.go +++ b/event/relations.go @@ -15,11 +15,10 @@ import ( type RelationType string const ( - RelReplace RelationType = "m.replace" - RelReference RelationType = "m.reference" - RelAnnotation RelationType = "m.annotation" - RelThread RelationType = "m.thread" - RelBeeperTranscription RelationType = "com.beeper.transcription" + RelReplace RelationType = "m.replace" + RelReference RelationType = "m.reference" + RelAnnotation RelationType = "m.annotation" + RelThread RelationType = "m.thread" ) type RelatesTo struct { @@ -34,7 +33,7 @@ type RelatesTo struct { type InReplyTo struct { EventID id.EventID `json:"event_id,omitempty"` - UnstableRoomID id.RoomID `json:"com.beeper.cross_room_id,omitempty"` + UnstableRoomID id.RoomID `json:"room_id,omitempty"` } func (rel *RelatesTo) Copy() *RelatesTo { @@ -101,10 +100,6 @@ func (rel *RelatesTo) SetReplace(mxid id.EventID) *RelatesTo { } func (rel *RelatesTo) SetReplyTo(mxid id.EventID) *RelatesTo { - if rel.Type != RelThread { - rel.Type = "" - rel.EventID = "" - } rel.InReplyTo = &InReplyTo{EventID: mxid} rel.IsFallingBack = false return rel diff --git a/event/reply.go b/event/reply.go index 5f55bb80..73f8cfc7 100644 --- a/event/reply.go +++ b/event/reply.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,6 +7,7 @@ package event import ( + "fmt" "regexp" "strings" @@ -32,13 +33,12 @@ func TrimReplyFallbackText(text string) string { } func (content *MessageEventContent) RemoveReplyFallback() { - if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML { - origHTML := content.FormattedBody - content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) - if content.FormattedBody != origHTML { - content.Body = TrimReplyFallbackText(content.Body) - content.replyFallbackRemoved = true + if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved { + if content.Format == FormatHTML { + content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) } + content.Body = TrimReplyFallbackText(content.Body) + content.replyFallbackRemoved = true } } @@ -47,28 +47,52 @@ func (content *MessageEventContent) GetReplyTo() id.EventID { return content.RelatesTo.GetReplyTo() } -func (content *MessageEventContent) SetReply(inReplyTo *Event) { - if content.RelatesTo == nil { - content.RelatesTo = &RelatesTo{} +const ReplyFormat = `
      In reply to %s
      %s
      ` + +func (evt *Event) GenerateReplyFallbackHTML() string { + parsedContent, ok := evt.Content.Parsed.(*MessageEventContent) + if !ok { + return "" } - content.RelatesTo.SetReplyTo(inReplyTo.ID) - if content.Mentions == nil { - content.Mentions = &Mentions{} + parsedContent.RemoveReplyFallback() + body := parsedContent.FormattedBody + if len(body) == 0 { + body = TextToHTML(parsedContent.Body) } - content.Mentions.Add(inReplyTo.Sender) + + senderDisplayName := evt.Sender + + return fmt.Sprintf(ReplyFormat, evt.RoomID, evt.ID, evt.Sender, senderDisplayName, body) } -func (content *MessageEventContent) SetThread(inReplyTo *Event) { - root := inReplyTo.ID - relatable, ok := inReplyTo.Content.Parsed.(Relatable) - if ok { - targetRoot := relatable.OptionalGetRelatesTo().GetThreadParent() - if targetRoot != "" { - root = targetRoot - } +func (evt *Event) GenerateReplyFallbackText() string { + parsedContent, ok := evt.Content.Parsed.(*MessageEventContent) + if !ok { + return "" } - if content.RelatesTo == nil { - content.RelatesTo = &RelatesTo{} + parsedContent.RemoveReplyFallback() + body := parsedContent.Body + lines := strings.Split(strings.TrimSpace(body), "\n") + firstLine, lines := lines[0], lines[1:] + + senderDisplayName := evt.Sender + + var fallbackText strings.Builder + _, _ = fmt.Fprintf(&fallbackText, "> <%s> %s", senderDisplayName, firstLine) + for _, line := range lines { + _, _ = fmt.Fprintf(&fallbackText, "\n> %s", line) + } + fallbackText.WriteString("\n\n") + return fallbackText.String() +} + +func (content *MessageEventContent) SetReply(inReplyTo *Event) { + content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID) + + if content.MsgType == MsgText || content.MsgType == MsgNotice { + content.EnsureHasHTML() + content.FormattedBody = inReplyTo.GenerateReplyFallbackHTML() + content.FormattedBody + content.Body = inReplyTo.GenerateReplyFallbackText() + content.Body + content.replyFallbackRemoved = false } - content.RelatesTo.SetThread(root, inReplyTo.ID) } diff --git a/event/state.go b/event/state.go index ace170a5..6a067cae 100644 --- a/event/state.go +++ b/event/state.go @@ -7,12 +7,6 @@ package event import ( - "encoding/base64" - "encoding/json" - "slices" - - "go.mau.fi/util/jsontime" - "maunium.net/go/mautrix/id" ) @@ -32,9 +26,8 @@ type RoomNameEventContent struct { // RoomAvatarEventContent represents the content of a m.room.avatar state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomavatar type RoomAvatarEventContent struct { - URL id.ContentURIString `json:"url,omitempty"` - Info *FileInfo `json:"info,omitempty"` - MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` + URL id.ContentURIString `json:"url"` + Info *FileInfo `json:"info,omitempty"` } // ServerACLEventContent represents the content of a m.room.server_acl state event. @@ -48,52 +41,7 @@ type ServerACLEventContent struct { // TopicEventContent represents the content of a m.room.topic state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomtopic type TopicEventContent struct { - Topic string `json:"topic"` - ExtensibleTopic *ExtensibleTopic `json:"m.topic,omitempty"` -} - -// ExtensibleTopic represents the contents of the m.topic field within the -// m.room.topic state event as described in [MSC3765]. -// -// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 -type ExtensibleTopic = ExtensibleTextContainer - -type ExtensibleTextContainer struct { - Text []ExtensibleText `json:"m.text"` -} - -func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool { - if c == nil || description == nil { - return c == description - } - return slices.Equal(c.Text, description.Text) -} - -func MakeExtensibleText(text string) *ExtensibleTextContainer { - return &ExtensibleTextContainer{ - Text: []ExtensibleText{{ - Body: text, - MimeType: "text/plain", - }}, - } -} - -func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer { - return &ExtensibleTextContainer{ - Text: []ExtensibleText{{ - Body: plaintext, - MimeType: "text/plain", - }, { - Body: html, - MimeType: "text/html", - }}, - } -} - -// ExtensibleText represents the contents of an m.text field. -type ExtensibleText struct { - MimeType string `json:"mimetype,omitempty"` - Body string `json:"body"` + Topic string `json:"topic"` } // TombstoneEventContent represents the content of a m.room.tombstone state event. @@ -103,64 +51,35 @@ type TombstoneEventContent struct { ReplacementRoom id.RoomID `json:"replacement_room"` } -func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID { - if tec == nil { - return "" - } - return tec.ReplacementRoom -} - type Predecessor struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` } -// Deprecated: use id.RoomVersion instead -type RoomVersion = id.RoomVersion +type RoomVersion string -// Deprecated: use id.RoomVX constants instead const ( - RoomV1 = id.RoomV1 - RoomV2 = id.RoomV2 - RoomV3 = id.RoomV3 - RoomV4 = id.RoomV4 - RoomV5 = id.RoomV5 - RoomV6 = id.RoomV6 - RoomV7 = id.RoomV7 - RoomV8 = id.RoomV8 - RoomV9 = id.RoomV9 - RoomV10 = id.RoomV10 - RoomV11 = id.RoomV11 - RoomV12 = id.RoomV12 + RoomV1 RoomVersion = "1" + RoomV2 RoomVersion = "2" + RoomV3 RoomVersion = "3" + RoomV4 RoomVersion = "4" + RoomV5 RoomVersion = "5" + RoomV6 RoomVersion = "6" + RoomV7 RoomVersion = "7" + RoomV8 RoomVersion = "8" + RoomV9 RoomVersion = "9" + RoomV10 RoomVersion = "10" + RoomV11 RoomVersion = "11" ) // CreateEventContent represents the content of a m.room.create state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomcreate type CreateEventContent struct { - Type RoomType `json:"type,omitempty"` - Federate *bool `json:"m.federate,omitempty"` - RoomVersion id.RoomVersion `json:"room_version,omitempty"` - Predecessor *Predecessor `json:"predecessor,omitempty"` - - // Room v12+ only - AdditionalCreators []id.UserID `json:"additional_creators,omitempty"` - - // Deprecated: use the event sender instead - Creator id.UserID `json:"creator,omitempty"` -} - -func (cec *CreateEventContent) GetPredecessor() (p Predecessor) { - if cec != nil && cec.Predecessor != nil { - p = *cec.Predecessor - } - return -} - -func (cec *CreateEventContent) SupportsCreatorPower() bool { - if cec == nil { - return false - } - return cec.RoomVersion.PrivilegedRoomCreators() + Type RoomType `json:"type,omitempty"` + Creator id.UserID `json:"creator,omitempty"` + Federate bool `json:"m.federate,omitempty"` + RoomVersion RoomVersion `json:"room_version,omitempty"` + Predecessor *Predecessor `json:"predecessor,omitempty"` } // JoinRule specifies how open a room is to new members. @@ -237,9 +156,6 @@ type BridgeInfoSection struct { DisplayName string `json:"displayname,omitempty"` AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` ExternalURL string `json:"external_url,omitempty"` - - Receiver string `json:"fi.mau.receiver,omitempty"` - MessageRequest bool `json:"com.beeper.message_request,omitempty"` } // BridgeEventContent represents the content of a m.bridge state event. @@ -253,32 +169,6 @@ type BridgeEventContent struct { BeeperRoomType string `json:"com.beeper.room_type,omitempty"` BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` - - TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` - TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"` -} - -// DisappearingType represents the type of a disappearing message timer. -type DisappearingType string - -const ( - DisappearingTypeNone DisappearingType = "" - DisappearingTypeAfterRead DisappearingType = "after_read" - DisappearingTypeAfterSend DisappearingType = "after_send" -) - -type BeeperDisappearingTimer struct { - Type DisappearingType `json:"type"` - Timer jsontime.Milliseconds `json:"timer"` -} - -type marshalableBeeperDisappearingTimer BeeperDisappearingTimer - -func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) { - if bdt == nil || bdt.Type == DisappearingTypeNone { - return []byte("{}"), nil - } - return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt)) } type SpaceChildEventContent struct { @@ -292,66 +182,20 @@ type SpaceParentEventContent struct { Canonical bool `json:"canonical,omitempty"` } -type PolicyRecommendation string - -const ( - PolicyRecommendationBan PolicyRecommendation = "m.ban" - PolicyRecommendationUnstableTakedown PolicyRecommendation = "org.matrix.msc4204.takedown" - PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" - PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" -) - -type PolicyHashes struct { - SHA256 string `json:"sha256"` -} - -func (ph *PolicyHashes) DecodeSHA256() *[32]byte { - if ph == nil || ph.SHA256 == "" { - return nil - } - decoded, _ := base64.StdEncoding.DecodeString(ph.SHA256) - if len(decoded) == 32 { - return (*[32]byte)(decoded) - } - return nil -} - // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. // https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists type ModPolicyContent struct { - Entity string `json:"entity,omitempty"` - Reason string `json:"reason"` - Recommendation PolicyRecommendation `json:"recommendation"` - UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"` + Entity string `json:"entity"` + Reason string `json:"reason"` + Recommendation string `json:"recommendation"` } -func (mpc *ModPolicyContent) EntityOrHash() string { - if mpc.UnstableHashes != nil && mpc.UnstableHashes.SHA256 != "" { - return mpc.UnstableHashes.SHA256 - } - return mpc.Entity +// Deprecated: MSC2716 has been abandoned +type InsertionMarkerContent struct { + InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"` + Timestamp int64 `json:"com.beeper.timestamp,omitempty"` } type ElementFunctionalMembersContent struct { ServiceMembers []id.UserID `json:"service_members"` } - -func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool { - if slices.Contains(efmc.ServiceMembers, mxid) { - return false - } - efmc.ServiceMembers = append(efmc.ServiceMembers, mxid) - return true -} - -type PolicyServerPublicKeys struct { - Ed25519 id.Ed25519 `json:"ed25519,omitempty"` -} - -type RoomPolicyEventContent struct { - Via string `json:"via,omitempty"` - PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"` - - // Deprecated, only for legacy use - PublicKey id.Ed25519 `json:"public_key,omitempty"` -} diff --git a/event/type.go b/event/type.go index 80b86728..162e2ce7 100644 --- a/event/type.go +++ b/event/type.go @@ -108,14 +108,13 @@ func (et *Type) IsCustom() bool { func (et *Type) GuessClass() TypeClass { switch et.Type { - case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type, + case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, - StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type: + StateInsertionMarker.Type, StateElementFunctionalMembers.Type: return StateEventType - case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type: + case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type, @@ -127,8 +126,7 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type, InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, - CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, - EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type: + CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -151,7 +149,7 @@ func (et *Type) MarshalJSON() ([]byte, error) { return json.Marshal(&et.Type) } -func (et *Type) UnmarshalText(data []byte) error { +func (et Type) UnmarshalText(data []byte) error { et.Type = string(data) et.Class = et.GuessClass() return nil @@ -161,11 +159,11 @@ func (et Type) MarshalText() ([]byte, error) { return []byte(et.Type), nil } -func (et Type) String() string { +func (et *Type) String() string { return et.Type } -func (et Type) Repr() string { +func (et *Type) Repr() string { return fmt.Sprintf("%s (%s)", et.Type, et.Class.Name()) } @@ -178,7 +176,6 @@ var ( StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType} StateGuestAccess = Type{"m.room.guest_access", StateEventType} StateMember = Type{"m.room.member", StateEventType} - StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType} StatePowerLevels = Type{"m.room.power_levels", StateEventType} StateRoomName = Type{"m.room.name", StateEventType} StateTopic = Type{"m.room.topic", StateEventType} @@ -195,20 +192,10 @@ var ( StateSpaceChild = Type{"m.space.child", StateEventType} StateSpaceParent = Type{"m.space.parent", StateEventType} - StateRoomPolicy = Type{"m.room.policy", StateEventType} - StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType} - - StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType} - StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType} - StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType} - StateUnstablePolicyRoom = Type{"org.matrix.mjolnir.rule.room", StateEventType} - StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType} - StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType} + // Deprecated: MSC2716 has been abandoned + StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} - StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} - StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} - StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType} ) // Message events @@ -237,24 +224,14 @@ var ( CallNegotiate = Type{"m.call.negotiate", MessageEventType} CallHangup = Type{"m.call.hangup", MessageEventType} - BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} - BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} - BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} - BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType} - BeeperSendState = Type{"com.beeper.send_state", MessageEventType} - - EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} - EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} - EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType} + BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} ) // Ephemeral events var ( - EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} - EphemeralEventTyping = Type{"m.typing", EphemeralEventType} - EphemeralEventPresence = Type{"m.presence", EphemeralEventType} - EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType} - BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType} + EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} + EphemeralEventTyping = Type{"m.typing", EphemeralEventType} + EphemeralEventPresence = Type{"m.presence", EphemeralEventType} ) // Account data events diff --git a/event/voip.go b/event/voip.go index cd8364a1..28f56c95 100644 --- a/event/voip.go +++ b/event/voip.go @@ -76,7 +76,7 @@ func (cv *CallVersion) Int() (int, error) { type BaseCallEventContent struct { CallID string `json:"call_id"` PartyID string `json:"party_id"` - Version CallVersion `json:"version,omitempty"` + Version CallVersion `json:"version"` } type CallInviteEventContent struct { diff --git a/example/main.go b/example/main.go index 2bf4bef3..d8006d46 100644 --- a/example/main.go +++ b/example/main.go @@ -143,7 +143,7 @@ func main() { if err != nil { log.Error().Err(err).Msg("Failed to send event") } else { - log.Info().Stringer("event_id", resp.EventID).Msg("Event sent") + log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent") } } cancelSync() diff --git a/federation/cache.go b/federation/cache.go deleted file mode 100644 index 24154974..00000000 --- a/federation/cache.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package federation - -import ( - "errors" - "fmt" - "math" - "sync" - "time" -) - -// ResolutionCache is an interface for caching resolved server names. -type ResolutionCache interface { - StoreResolution(*ResolvedServerName) - // LoadResolution loads a resolved server name from the cache. - // Expired entries MUST NOT be returned. - LoadResolution(serverName string) (*ResolvedServerName, error) -} - -type KeyCache interface { - StoreKeys(*ServerKeyResponse) - StoreFetchError(serverName string, err error) - ShouldReQuery(serverName string) bool - LoadKeys(serverName string) (*ServerKeyResponse, error) -} - -type InMemoryCache struct { - MinKeyRefetchDelay time.Duration - - resolutions map[string]*ResolvedServerName - resolutionsLock sync.RWMutex - keys map[string]*ServerKeyResponse - lastReQueryAt map[string]time.Time - lastError map[string]*resolutionErrorCache - keysLock sync.RWMutex -} - -var ( - _ ResolutionCache = (*InMemoryCache)(nil) - _ KeyCache = (*InMemoryCache)(nil) -) - -func NewInMemoryCache() *InMemoryCache { - return &InMemoryCache{ - resolutions: make(map[string]*ResolvedServerName), - keys: make(map[string]*ServerKeyResponse), - lastReQueryAt: make(map[string]time.Time), - lastError: make(map[string]*resolutionErrorCache), - MinKeyRefetchDelay: 1 * time.Hour, - } -} - -func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) { - c.resolutionsLock.Lock() - defer c.resolutionsLock.Unlock() - c.resolutions[resolution.ServerName] = resolution -} - -func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) { - c.resolutionsLock.RLock() - defer c.resolutionsLock.RUnlock() - resolution, ok := c.resolutions[serverName] - if !ok || time.Until(resolution.Expires) < 0 { - return nil, nil - } - return resolution, nil -} - -func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) { - c.keysLock.Lock() - defer c.keysLock.Unlock() - c.keys[keys.ServerName] = keys - delete(c.lastError, keys.ServerName) -} - -type resolutionErrorCache struct { - Error error - Time time.Time - Count int -} - -const MaxBackoff = 7 * 24 * time.Hour - -func (rec *resolutionErrorCache) ShouldRetry() bool { - backoff := time.Duration(math.Exp(float64(rec.Count))) * time.Second - return time.Since(rec.Time) > backoff -} - -var ErrRecentKeyQueryFailed = errors.New("last retry was too recent") - -func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) { - c.keysLock.RLock() - defer c.keysLock.RUnlock() - keys, ok := c.keys[serverName] - if !ok || time.Until(keys.ValidUntilTS.Time) < 0 { - err, ok := c.lastError[serverName] - if ok && !err.ShouldRetry() { - return nil, fmt.Errorf( - "%w (%s ago) and failed with %w", - ErrRecentKeyQueryFailed, - time.Since(err.Time).String(), - err.Error, - ) - } - return nil, nil - } - return keys, nil -} - -func (c *InMemoryCache) StoreFetchError(serverName string, err error) { - c.keysLock.Lock() - defer c.keysLock.Unlock() - errorCache, ok := c.lastError[serverName] - if ok { - errorCache.Time = time.Now() - errorCache.Error = err - errorCache.Count++ - } else { - c.lastError[serverName] = &resolutionErrorCache{Error: err, Time: time.Now(), Count: 1} - } -} - -func (c *InMemoryCache) ShouldReQuery(serverName string) bool { - c.keysLock.Lock() - defer c.keysLock.Unlock() - lastQuery, ok := c.lastReQueryAt[serverName] - if ok && time.Since(lastQuery) < c.MinKeyRefetchDelay { - return false - } - c.lastReQueryAt[serverName] = time.Now() - return true -} - -type noopCache struct{} - -func (*noopCache) StoreKeys(_ *ServerKeyResponse) {} -func (*noopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil } -func (*noopCache) StoreFetchError(_ string, _ error) {} -func (*noopCache) ShouldReQuery(_ string) bool { return true } -func (*noopCache) StoreResolution(_ *ResolvedServerName) {} -func (*noopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil } - -var ( - _ ResolutionCache = (*noopCache)(nil) - _ KeyCache = (*noopCache)(nil) -) - -var NoopCache *noopCache diff --git a/federation/client.go b/federation/client.go index 183fb5d1..098df095 100644 --- a/federation/client.go +++ b/federation/client.go @@ -9,6 +9,7 @@ package federation import ( "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "io" @@ -21,7 +22,6 @@ import ( "go.mau.fi/util/jsontime" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -30,25 +30,17 @@ type Client struct { ServerName string UserAgent string Key *SigningKey - - ResponseSizeLimit int64 } -func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { +func NewClient(serverName string, key *SigningKey) *Client { return &Client{ HTTP: &http.Client{ - Transport: NewServerResolvingTransport(cache), + Transport: NewServerResolvingTransport(), Timeout: 120 * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Federation requests do not allow redirects. - return http.ErrUseLastResponse - }, }, UserAgent: mautrix.DefaultUserAgent, ServerName: serverName, Key: key, - - ResponseSizeLimit: mautrix.DefaultResponseSizeLimit, } } @@ -62,7 +54,7 @@ func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *Serve return } -func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *QueryKeysResponse, err error) { +func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) { err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp) return } @@ -89,7 +81,7 @@ type RespSendTransaction struct { } func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { - err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp) + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp) return } @@ -228,26 +220,6 @@ func (c *Client) Query(ctx context.Context, serverName, queryType string, queryP return } -func queryToValues(query map[string]string) url.Values { - values := make(url.Values, len(query)) - for k, v := range query { - values[k] = []string{v} - } - return values -} - -func (c *Client) PublicRooms(ctx context.Context, serverName string, req *mautrix.ReqPublicRooms) (resp *mautrix.RespPublicRooms, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: serverName, - Method: http.MethodGet, - Path: URLPath{"v1", "publicRooms"}, - Query: queryToValues(req.Query()), - Authenticate: true, - ResponseJSON: &resp, - }) - return -} - type RespOpenIDUserInfo struct { Sub id.UserID `json:"sub"` } @@ -263,169 +235,6 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken return } -type ReqMakeJoin struct { - RoomID id.RoomID - UserID id.UserID - Via string - SupportedVersions []id.RoomVersion -} - -type RespMakeJoin struct { - RoomVersion id.RoomVersion `json:"room_version"` - Event PDU `json:"event"` -} - -type ReqSendJoin struct { - RoomID id.RoomID - EventID id.EventID - OmitMembers bool - Event PDU - Via string -} - -type ReqSendKnock struct { - RoomID id.RoomID - EventID id.EventID - Event PDU - Via string -} - -type RespSendJoin struct { - AuthChain []PDU `json:"auth_chain"` - Event PDU `json:"event"` - MembersOmitted bool `json:"members_omitted"` - ServersInRoom []string `json:"servers_in_room"` - State []PDU `json:"state"` -} - -type RespSendKnock struct { - KnockRoomState []PDU `json:"knock_room_state"` -} - -type ReqSendInvite struct { - RoomID id.RoomID `json:"-"` - UserID id.UserID `json:"-"` - Event PDU `json:"event"` - InviteRoomState []PDU `json:"invite_room_state"` - RoomVersion id.RoomVersion `json:"room_version"` -} - -type RespSendInvite struct { - Event PDU `json:"event"` -} - -type ReqMakeLeave struct { - RoomID id.RoomID - UserID id.UserID - Via string -} - -type ReqSendLeave struct { - RoomID id.RoomID - EventID id.EventID - Event PDU - Via string -} - -type ( - ReqMakeKnock = ReqMakeJoin - RespMakeKnock = RespMakeJoin - RespMakeLeave = RespMakeJoin -) - -func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) { - versions := make([]string, len(req.SupportedVersions)) - for i, v := range req.SupportedVersions { - versions[i] = string(v) - } - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodGet, - Path: URLPath{"v1", "make_join", req.RoomID, req.UserID}, - Query: url.Values{"ver": versions}, - Authenticate: true, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) { - versions := make([]string, len(req.SupportedVersions)) - for i, v := range req.SupportedVersions { - versions[i] = string(v) - } - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodGet, - Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID}, - Query: url.Values{"ver": versions}, - Authenticate: true, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodPut, - Path: URLPath{"v2", "send_join", req.RoomID, req.EventID}, - Query: url.Values{ - "omit_members": {strconv.FormatBool(req.OmitMembers)}, - }, - Authenticate: true, - RequestJSON: req.Event, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodPut, - Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID}, - Authenticate: true, - RequestJSON: req.Event, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.UserID.Homeserver(), - Method: http.MethodPut, - Path: URLPath{"v2", "invite", req.RoomID, req.UserID}, - Authenticate: true, - RequestJSON: req, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodGet, - Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID}, - Authenticate: true, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodPut, - Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID}, - Authenticate: true, - RequestJSON: req.Event, - }) - return -} - type URLPath []any func (fup URLPath) FullPath() []any { @@ -477,27 +286,15 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } - if !params.DontReadBody { - defer resp.Body.Close() - } + defer func() { + _ = resp.Body.Close() + }() var body []byte - if resp.StatusCode >= 300 { + if resp.StatusCode >= 400 { body, err = mautrix.ParseErrorResponse(req, resp) return body, resp, err } else if params.ResponseJSON != nil || !params.DontReadBody { - if resp.ContentLength > c.ResponseSizeLimit { - return body, resp, mautrix.HTTPError{ - Request: req, - Response: resp, - - Message: "not reading response", - WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024), - } - } - body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) - if err == nil && len(body) > int(c.ResponseSizeLimit) { - err = mautrix.ErrBodyReadReachedLimit - } + body, err = io.ReadAll(resp.Body) if err != nil { return body, resp, mautrix.HTTPError{ Request: req, @@ -557,12 +354,16 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt Message: "client not configured for authentication", } } + var contentAny any + if reqJSON != nil { + contentAny = reqJSON + } auth, err := (&signableRequest{ Method: req.Method, URI: reqURL.RequestURI(), Origin: c.ServerName, Destination: params.ServerName, - Content: reqJSON, + Content: contentAny, }).Sign(c.Key) if err != nil { return nil, mautrix.HTTPError{ @@ -576,19 +377,11 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt } type signableRequest struct { - Method string `json:"method"` - URI string `json:"uri"` - Origin string `json:"origin"` - Destination string `json:"destination"` - Content json.RawMessage `json:"content,omitempty"` -} - -func (r *signableRequest) Verify(key id.SigningKey, sig string) error { - message, err := json.Marshal(r) - if err != nil { - return fmt.Errorf("failed to marshal data: %w", err) - } - return signutil.VerifyJSONRaw(key, sig, message) + Method string `json:"method"` + URI string `json:"uri"` + Origin string `json:"origin"` + Destination string `json:"destination"` + Content any `json:"content,omitempty"` } func (r *signableRequest) Sign(key *SigningKey) (string, error) { @@ -596,10 +389,11 @@ func (r *signableRequest) Sign(key *SigningKey) (string, error) { if err != nil { return "", err } - return XMatrixAuth{ - Origin: r.Origin, - Destination: r.Destination, - KeyID: key.ID, - Signature: sig, - }.String(), nil + return fmt.Sprintf( + `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, + r.Origin, + r.Destination, + key.ID, + base64.RawURLEncoding.EncodeToString(sig), + ), nil } diff --git a/federation/client_test.go b/federation/client_test.go index ece399ea..ba3c3ed4 100644 --- a/federation/client_test.go +++ b/federation/client_test.go @@ -16,7 +16,7 @@ import ( ) func TestClient_Version(t *testing.T) { - cli := federation.NewClient("", nil, nil) + cli := federation.NewClient("", nil) resp, err := cli.Version(context.TODO(), "maunium.net") require.NoError(t, err) require.Equal(t, "Synapse", resp.Server.Name) diff --git a/federation/context.go b/federation/context.go deleted file mode 100644 index eedb2dc1..00000000 --- a/federation/context.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package federation - -import ( - "context" - "net/http" -) - -type contextKey int - -const ( - contextKeyIPPort contextKey = iota - contextKeyDestinationServer - contextKeyOriginServer -) - -func DestinationServerNameFromRequest(r *http.Request) string { - return DestinationServerName(r.Context()) -} - -func DestinationServerName(ctx context.Context) string { - if dest, ok := ctx.Value(contextKeyDestinationServer).(string); ok { - return dest - } - return "" -} - -func OriginServerNameFromRequest(r *http.Request) string { - return OriginServerName(r.Context()) -} - -func OriginServerName(ctx context.Context) string { - if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok { - return origin - } - return "" -} diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go deleted file mode 100644 index c72933c2..00000000 --- a/federation/eventauth/eventauth.go +++ /dev/null @@ -1,851 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package eventauth - -import ( - "encoding/json" - "encoding/json/jsontext" - "errors" - "fmt" - "slices" - "strconv" - "strings" - - "github.com/tidwall/gjson" - "go.mau.fi/util/exgjson" - "go.mau.fi/util/exstrings" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/federation/signutil" - "maunium.net/go/mautrix/id" -) - -type AuthFailError struct { - Index string - Message string - Wrapped error -} - -func (afe AuthFailError) Error() string { - if afe.Message != "" { - return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message) - } else if afe.Wrapped != nil { - return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error()) - } - return fmt.Sprintf("fail %s", afe.Index) -} - -func (afe AuthFailError) Unwrap() error { - return afe.Wrapped -} - -var mFederatePath = exgjson.Path("m.federate") - -var ( - ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"} - ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"} - ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"} - ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion} - ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"} - ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"} - - ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"} - ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"} - ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"} - ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"} - - ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"} - ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"} - ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"} - ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"} - ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"} - ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"} - ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"} - ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"} - ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"} - - ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"} - - ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"} - ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"} - ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"} - ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"} - ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"} - ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"} - ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"} - ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"} - ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"} - ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"} - ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"} - ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"} - ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"} - ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"} - ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"} - ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"} - ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"} - ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"} - ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"} - ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"} - ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"} - ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"} - ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"} - ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"} - ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"} - ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"} - ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"} - ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"} - - ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"} - - ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"} - - ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"} - - ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"} - - ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"} - ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"} - ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"} - ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"} - ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"} - ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"} - ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"} -) - -func isRejected(evt *pdu.PDU) bool { - return evt.InternalMeta.Rejected -} - -type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error) - -func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error { - if evt.Type == event.StateCreate.Type { - // 1. If type is m.room.create: - return authorizeCreate(roomVersion, evt) - } - var createEvt *pdu.PDU - if roomVersion.RoomIDIsCreateEventID() { - // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event, - // with the sigil ! instead of $, reject. - if len(evt.RoomID) != 44 { - return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID)) - } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil { - return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err) - } else if len(createEvts) != 1 { - return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID) - } else if isRejected(createEvts[0]) { - return ErrRejectedCreateEvent - } else { - createEvt = createEvts[0] - } - } - authEvents, err := getEvents(evt.AuthEvents) - if err != nil { - return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err) - } - expectedAuthEvents := evt.AuthEventSelection(roomVersion) - deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents)) - // 3. Considering the event’s auth_events: - for i, ae := range authEvents { - authEvtID := evt.AuthEvents[i] - if ae == nil { - return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID) - } else if ae.StateKey == nil { - // This approximately falls under rule 3.2. - return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID) - } - key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey} - if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound { - // 3.1. If there are duplicate entries for a given type and state_key pair, reject. - return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID) - } else if !expectedAuthEvents.Has(key) { - // 3.2. If there are entries whose type and state_key don’t match those specified by - // the auth events selection algorithm described in the server specification, reject. - return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey) - } else if isRejected(ae) { - // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. - return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID) - } else if ae.RoomID != evt.RoomID { - // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject. - return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID) - } else { - deduplicator[key] = authEvtID - } - if ae.Type == event.StateCreate.Type { - if createEvt == nil { - createEvt = ae - } else { - // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+ - panic(fmt.Errorf("impossible case: multiple create events found in auth events")) - } - } - } - if createEvt == nil { - // This comes either from auth_events or room_id depending on the room version. - // The checks above make sure it's from the right source. - return ErrNoCreateEvent - } - if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() { - // 4. If the content of the m.room.create event in the room state has the property m.federate set to false, - // and the sender domain of the event does not match the sender domain of the create event, reject. - return ErrFederationDisabled - } - if evt.Type == event.StateMember.Type { - // 5. If type is m.room.member: - return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey) - } - senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) - if senderMembership != event.MembershipJoin { - // 6. If the sender’s current membership state is not join, reject. - return ErrNotInRoom - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - senderPL := powerLevels.GetUserLevel(evt.Sender) - if evt.Type == event.StateThirdPartyInvite.Type { - // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level. - if senderPL >= powerLevels.Invite() { - return nil - } - return ErrInsufficientPowerForThirdPartyInvite - } - typeClass := event.MessageEventType - if evt.StateKey != nil { - typeClass = event.StateEventType - } - evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass}) - if evtLevel > senderPL { - // 8. If the event type’s required power level is greater than the sender’s power level, reject. - return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL) - } - - if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() { - // 9. If the event has a state_key that starts with an @ and does not match the sender, reject. - return ErrMismatchingPrivateStateKey - } - - if evt.Type == event.StatePowerLevels.Type { - // 10. If type is m.room.power_levels: - return authorizePowerLevels(roomVersion, evt, createEvt, authEvents) - } - - // 11. Otherwise, allow. - return nil -} - -var ErrUserIDNotAString = errors.New("not a string") -var ErrUserIDNotValid = errors.New("not a valid user ID") - -func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error { - if userID.Type != gjson.String { - return ErrUserIDNotAString - } - // In a future room version, user IDs will have stricter validation - _, _, err := id.UserID(userID.Str).Parse() - if err != nil { - return ErrUserIDNotValid - } - return nil -} - -func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error { - if len(evt.PrevEvents) > 0 { - // 1.1. If it has any prev_events, reject. - return ErrCreateHasPrevEvents - } - if roomVersion.RoomIDIsCreateEventID() { - if evt.RoomID != "" { - // 1.2. If the event has a room_id, reject. - return ErrCreateHasRoomID - } - } else { - _, _, server := id.ParseCommonIdentifier(evt.RoomID) - if server == "" || server != evt.Sender.Homeserver() { - // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject. - return ErrRoomIDDoesntMatchSender - } - } - if !roomVersion.IsKnown() { - // 1.3. If content.room_version is present and is not a recognised version, reject. - return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion) - } - if roomVersion.PrivilegedRoomCreators() { - additionalCreators := gjson.GetBytes(evt.Content, "additional_creators") - if additionalCreators.Exists() { - if !additionalCreators.IsArray() { - return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators) - } - for i, item := range additionalCreators.Array() { - // 1.4. If additional_creators is present in content and is not an array of strings - // where each string passes the same user ID validation applied to sender, reject. - if err := isValidUserID(roomVersion, item); err != nil { - return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err) - } - } - } - } - if roomVersion.CreatorInContent() { - // 1.4. (v10 and below) If content has no creator property, reject. - if !gjson.GetBytes(evt.Content, "creator").Exists() { - return ErrMissingCreator - } - } - // 1.5. Otherwise, allow. - return nil -} - -func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error { - membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str) - if evt.StateKey == nil { - // 5.1. If there is no state_key property, or no membership property in content, reject. - return ErrMemberNotState - } - authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str) - if authorizedVia != "" { - homeserver := authorizedVia.Homeserver() - err := evt.VerifySignature(roomVersion, homeserver, getKey) - if err != nil { - // 5.2. If content has a join_authorised_via_users_server key: - // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject. - return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err) - } - } - targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave")) - senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) - switch membership { - case event.MembershipJoin: - createEvtID, err := createEvt.GetEventID(roomVersion) - if err != nil { - return fmt.Errorf("failed to get create event ID: %w", err) - } - creator := createEvt.Sender.String() - if roomVersion.CreatorInContent() { - creator = gjson.GetBytes(evt.Content, "creator").Str - } - if len(evt.PrevEvents) == 1 && - len(evt.AuthEvents) <= 1 && - evt.PrevEvents[0] == createEvtID && - *evt.StateKey == creator { - // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow. - return nil - } - // Spec wart: this would make more sense before the check above. - // Now you can set anyone as the sender of the first join. - if evt.Sender.String() != *evt.StateKey { - // 5.3.2. If the sender does not match state_key, reject. - return ErrCantJoinOtherUser - } - - if senderMembership == event.MembershipBan { - // 5.3.3. If the sender is banned, reject. - return ErrCantJoinBanned - } - - joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) - switch joinRule { - case event.JoinRuleKnock: - if !roomVersion.Knocks() { - return ErrInvalidJoinRule - } - fallthrough - case event.JoinRuleInvite: - // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join. - if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { - return nil - } - return ErrCantJoinWithoutInvite - case event.JoinRuleKnockRestricted: - if !roomVersion.KnockRestricted() { - return ErrInvalidJoinRule - } - fallthrough - case event.JoinRuleRestricted: - if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() { - return ErrInvalidJoinRule - } - if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { - // 5.3.5.1. If membership state is join or invite, allow. - return nil - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() { - // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject. - return ErrAuthoriserCantInvite - } - authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave))) - if authorizerMembership != event.MembershipJoin { - return ErrAuthoriserNotInRoom - } - // 5.3.5.3. Otherwise, allow. - return nil - case event.JoinRulePublic: - // 5.3.6. If the join_rule is public, allow. - return nil - default: - // 5.3.7. Otherwise, reject. - return ErrInvalidJoinRule - } - case event.MembershipInvite: - tpiVal := gjson.GetBytes(evt.Content, "third_party_invite") - if tpiVal.Exists() { - if targetPrevMembership == event.MembershipBan { - return ErrThirdPartyInviteBanned - } - signed := tpiVal.Get("signed") - mxid := signed.Get("mxid").Str - token := signed.Get("token").Str - if mxid == "" || token == "" { - // 5.4.1.2. If content.third_party_invite does not have a signed property, reject. - // 5.4.1.3. If signed does not have mxid and token properties, reject. - return ErrThirdPartyInviteMissingFields - } - if mxid != *evt.StateKey { - // 5.4.1.4. If mxid does not match state_key, reject. - return ErrThirdPartyInviteMXIDMismatch - } - tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token) - if tpiEvt == nil { - // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject. - return ErrThirdPartyInviteNotFound - } - if tpiEvt.Sender != evt.Sender { - // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject. - return ErrThirdPartyInviteSenderMismatch - } - var keys []id.Ed25519 - const ed25519Base64Len = 43 - oldPubKey := gjson.GetBytes(evt.Content, "public_key.token") - if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len { - keys = append(keys, id.Ed25519(oldPubKey.Str)) - } - gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool { - if key.Type != gjson.Number { - return false - } - if value.Type == gjson.String && len(value.Str) == ed25519Base64Len { - keys = append(keys, id.Ed25519(value.Str)) - } - return true - }) - rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str)) - var validated bool - for _, key := range keys { - if signutil.VerifyJSONAny(key, rawSigned) == nil { - validated = true - } - } - if validated { - // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow. - return nil - } - // 4.4.1.8. Otherwise, reject. - return ErrThirdPartyInviteNotSigned - } - if senderMembership != event.MembershipJoin { - // 5.4.2. If the sender’s current membership state is not join, reject. - return ErrInviterNotInRoom - } - // 5.4.3. If target user’s current membership state is join or ban, reject. - if targetPrevMembership == event.MembershipJoin { - return ErrInviteTargetAlreadyInRoom - } else if targetPrevMembership == event.MembershipBan { - return ErrInviteTargetBanned - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() { - // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow. - return nil - } - // 5.4.5. Otherwise, reject. - return ErrInsufficientPermissionForInvite - case event.MembershipLeave: - if evt.Sender.String() == *evt.StateKey { - // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock. - if senderMembership == event.MembershipInvite || - senderMembership == event.MembershipJoin || - (senderMembership == event.MembershipKnock && roomVersion.Knocks()) { - return nil - } - return ErrCantLeaveWithoutBeingInRoom - } - if senderMembership != event.MembershipJoin { - // 5.5.2. If the sender’s current membership state is not join, reject. - return ErrCantKickWithoutBeingInRoom - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - senderLevel := powerLevels.GetUserLevel(evt.Sender) - if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() { - // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject. - return ErrInsufficientPermissionForUnban - } - if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { - // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow. - return nil - } - // TODO separate errors for < kick and < target user level? - // 5.5.5. Otherwise, reject. - return ErrInsufficientPermissionForKick - case event.MembershipBan: - if senderMembership != event.MembershipJoin { - // 5.6.1. If the sender’s current membership state is not join, reject. - return ErrCantBanWithoutBeingInRoom - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - senderLevel := powerLevels.GetUserLevel(evt.Sender) - if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { - // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow. - return nil - } - // 5.6.3. Otherwise, reject. - return ErrInsufficientPermissionForBan - case event.MembershipKnock: - joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) - validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock - validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted - if !validKnockRule && !validKnockRestrictedRule { - // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject. - return ErrNotKnockableRoom - } - if evt.Sender.String() != *evt.StateKey { - // 5.7.2. If the sender does not match state_key, reject. - return ErrCantKnockOtherUser - } - if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin { - // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow. - return nil - } - // 5.7.4. Otherwise, reject. - return ErrCantKnockWhileInRoom - default: - // 5.8. Otherwise, the membership is unknown. Reject. - return ErrUnknownMembership - } -} - -func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error { - if roomVersion.ValidatePowerLevelInts() { - for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { - res := gjson.GetBytes(evt.Content, key) - if !res.Exists() { - continue - } - if parseIntWithVersion(roomVersion, res) == nil { - // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject. - return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) - } - } - for _, key := range []string{"events", "notifications"} { - obj := gjson.GetBytes(evt.Content, key) - if !obj.Exists() { - continue - } - // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject. - if !obj.IsObject() { - return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) - } - var err error - // 10.2. [...] are not an object with values that are integers, reject. - obj.ForEach(func(innerKey, value gjson.Result) bool { - if parseIntWithVersion(roomVersion, value) == nil { - err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str) - return false - } - return true - }) - if err != nil { - return err - } - } - } - var creators []id.UserID - if roomVersion.PrivilegedRoomCreators() { - creators = append(creators, createEvt.Sender) - gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool { - creators = append(creators, id.UserID(value.Str)) - return true - }) - } - users := gjson.GetBytes(evt.Content, "users") - if users.Exists() { - if !users.IsObject() { - // 10.3. If the users property in content is not an object [...], reject. - return fmt.Errorf("%w users", ErrTopLevelPLNotInteger) - } - var err error - users.ForEach(func(key, value gjson.Result) bool { - if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil { - // 10.3. [...] is not an object with keys that are valid user IDs [...], reject. - err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr) - return false - } - if parseIntWithVersion(roomVersion, value) == nil { - // 10.3. [...] is not an object [...] with values that are integers, reject. - err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str) - return false - } - // creators is only filled if the room version has privileged room creators - if slices.Contains(creators, id.UserID(key.Str)) { - // 10.4. If the users property in content contains the sender of the m.room.create event or any of - // the additional_creators array (if present) from the content of the m.room.create event, reject. - err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str) - return false - } - return true - }) - if err != nil { - return err - } - } - oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "") - if oldPL == nil { - // 10.5. If there is no previous m.room.power_levels event in the room, allow. - return nil - } - if slices.Contains(creators, evt.Sender) { - // Skip remaining checks for creators - return nil - } - senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String()))) - if senderPLPtr == nil { - senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default")) - if senderPLPtr == nil { - senderPLPtr = ptr.Ptr(0) - } - } - for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { - oldVal := gjson.GetBytes(oldPL.Content, key) - newVal := gjson.GetBytes(evt.Content, key) - if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil { - return err - } - } - if err := allowPowerChangeMap( - roomVersion, *senderPLPtr, "events", "", - gjson.GetBytes(oldPL.Content, "events"), - gjson.GetBytes(evt.Content, "events"), - ); err != nil { - return err - } - if err := allowPowerChangeMap( - roomVersion, *senderPLPtr, "notifications", "", - gjson.GetBytes(oldPL.Content, "notifications"), - gjson.GetBytes(evt.Content, "notifications"), - ); err != nil { - return err - } - if err := allowPowerChangeMap( - roomVersion, *senderPLPtr, "users", evt.Sender.String(), - gjson.GetBytes(oldPL.Content, "users"), - gjson.GetBytes(evt.Content, "users"), - ); err != nil { - return err - } - return nil -} - -func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) { - old.ForEach(func(key, value gjson.Result) bool { - newVal := new.Get(exgjson.Path(key.Str)) - err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal) - if err == nil && ownID != "" && key.Str != ownID { - parsedOldVal := parseIntWithVersion(roomVersion, value) - parsedNewVal := parseIntWithVersion(roomVersion, newVal) - if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal { - err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) - } - } - return err == nil - }) - if err != nil { - return - } - new.ForEach(func(key, value gjson.Result) bool { - err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value) - return err == nil - }) - return -} - -func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error { - oldVal := parseIntWithVersion(roomVersion, old) - newVal := parseIntWithVersion(roomVersion, new) - if oldVal == nil { - if newVal == nil || *newVal <= maxVal { - return nil - } - } else if newVal == nil { - if *oldVal <= maxVal { - return nil - } - } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) { - return nil - } - return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal) -} - -func stringifyForError(val gjson.Result) string { - if !val.Exists() { - return "null" - } - return val.Raw -} - -func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU { - for _, evt := range events { - if evt.Type == evtType && *evt.StateKey == stateKey { - return evt - } - } - return nil -} - -func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T { - return reader(findEvent(events, evtType, stateKey)) -} - -func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string { - return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string { - if evt == nil { - return defVal - } - res := gjson.GetBytes(evt.Content, fieldPath) - if res.Type != gjson.String { - return defVal - } - return res.Str - }) -} - -func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) { - var err error - powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent { - if evt == nil { - return nil - } - content := evt.Content - out := &event.PowerLevelsEventContent{} - if !roomVersion.ValidatePowerLevelInts() { - safeParsePowerLevels(content, out) - } else { - err = json.Unmarshal(content, out) - } - return out - }) - if err != nil { - // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+ - return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) - } - if roomVersion.PrivilegedRoomCreators() { - if powerLevels == nil { - powerLevels = &event.PowerLevelsEventContent{} - } - powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) - } - err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) - } - } else if powerLevels == nil { - powerLevels = &event.PowerLevelsEventContent{ - Users: map[id.UserID]int{ - createEvt.Sender: 100, - }, - } - } - return powerLevels, nil -} - -func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int { - if roomVersion.ValidatePowerLevelInts() { - if val.Type != gjson.Number { - return nil - } - return ptr.Ptr(int(val.Int())) - } - return parsePythonInt(val) -} - -func parsePythonInt(val gjson.Result) *int { - switch val.Type { - case gjson.True: - return ptr.Ptr(1) - case gjson.False: - return ptr.Ptr(0) - case gjson.Number: - return ptr.Ptr(int(val.Int())) - case gjson.String: - // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand - num, err := strconv.Atoi(strings.TrimSpace(val.Str)) - if err != nil { - return nil - } - return &num - default: - // Python int() doesn't accept nulls, arrays or dicts - return nil - } -} - -func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) { - *into = event.PowerLevelsEventContent{ - Users: make(map[id.UserID]int), - UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))), - Events: make(map[string]int), - EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))), - Notifications: nil, // irrelevant for event auth - StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")), - InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")), - KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")), - BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")), - RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")), - } - gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool { - if key.Type != gjson.String { - return false - } - val := parsePythonInt(value) - if val != nil { - into.Events[key.Str] = *val - } - return true - }) - gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool { - if key.Type != gjson.String { - return false - } - val := parsePythonInt(value) - if val == nil { - return false - } - userID := id.UserID(key.Str) - if _, _, err := userID.Parse(); err != nil { - return false - } - into.Users[userID] = *val - return true - }) -} diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go deleted file mode 100644 index d316f3c8..00000000 --- a/federation/eventauth/eventauth_internal_test.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package eventauth - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" -) - -type pythonIntTest struct { - Name string - Input string - Expected int64 -} - -var pythonIntTests = []pythonIntTest{ - {"True", `true`, 1}, - {"False", `false`, 0}, - {"SmallFloat", `3.1415`, 3}, - {"SmallFloatRoundDown", `10.999999999999999`, 10}, - {"SmallFloatRoundUp", `10.9999999999999999`, 11}, - {"BigFloatRoundDown", `1000000.9999999999`, 1000000}, - {"BigFloatRoundUp", `1000000.99999999999`, 1000001}, - {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992}, - {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994}, - {"Int64", `9223372036854775807`, 9223372036854775807}, - {"Int64String", `"9223372036854775807"`, 9223372036854775807}, - {"String", `"123"`, 123}, - {"InvalidFloatInString", `"123.456"`, 0}, - {"StringWithPlusSign", `"+123"`, 123}, - {"StringWithMinusSign", `"-123"`, -123}, - {"StringWithSpaces", `" 123 "`, 123}, - {"StringWithSpacesAndSign", `" -123 "`, -123}, - //{"StringWithUnderscores", `"123_456"`, 123456}, - //{"StringWithUnderscores", `"123_456"`, 123456}, - {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0}, - {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0}, - {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0}, - {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0}, - {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0}, - //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, -} - -func TestParsePythonInt(t *testing.T) { - for _, test := range pythonIntTests { - t.Run(test.Name, func(t *testing.T) { - output := parsePythonInt(gjson.Parse(test.Input)) - if strings.HasPrefix(test.Name, "Invalid") { - assert.Nil(t, output) - } else { - require.NotNil(t, output) - assert.Equal(t, int(test.Expected), *output) - } - }) - } -} diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go deleted file mode 100644 index e3c5cd76..00000000 --- a/federation/eventauth/eventauth_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package eventauth_test - -import ( - "embed" - "encoding/json/jsontext" - "encoding/json/v2" - "errors" - "io" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/federation/eventauth" - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -//go:embed *.jsonl -var data embed.FS - -type eventMap map[id.EventID]*pdu.PDU - -func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) { - output := make([]*pdu.PDU, len(ids)) - for i, evtID := range ids { - output[i] = em[evtID] - } - return output, nil -} - -func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) { - return "", time.Time{}, nil -} - -func TestAuthorize(t *testing.T) { - files := exerrors.Must(data.ReadDir(".")) - for _, file := range files { - t.Run(file.Name(), func(t *testing.T) { - decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name()))) - events := make(eventMap) - var roomVersion *id.RoomVersion - for i := 1; ; i++ { - var evt *pdu.PDU - err := json.UnmarshalDecode(decoder, &evt) - if errors.Is(err, io.EOF) { - break - } - require.NoError(t, err) - if roomVersion == nil { - require.Equal(t, evt.Type, "m.room.create") - roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str)) - } - expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str - evtID, err := evt.GetEventID(*roomVersion) - require.NoError(t, err) - require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i) - - // TODO allow redacted events - assert.True(t, evt.VerifyContentHash(), i) - - events[evtID] = evt - err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey) - if err != nil { - evt.InternalMeta.Rejected = true - } - // TODO allow testing intentionally rejected events - assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type) - } - }) - } - -} diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl deleted file mode 100644 index 2b751de3..00000000 --- a/federation/eventauth/testroom-v12-success.jsonl +++ /dev/null @@ -1,21 +0,0 @@ -{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}} -{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}} -{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}} -{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}} -{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} -{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}} -{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}} -{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}} -{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}} -{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}} diff --git a/federation/httpclient.go b/federation/httpclient.go index 2f8dbb4f..d6d97280 100644 --- a/federation/httpclient.go +++ b/federation/httpclient.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "sync" + "time" ) // ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests. @@ -21,20 +22,17 @@ type ServerResolvingTransport struct { Transport *http.Transport Dialer *net.Dialer - cache ResolutionCache - - resolveLocks map[string]*sync.Mutex - resolveLocksLock sync.Mutex + cache map[string]*ResolvedServerName + resolveLocks map[string]*sync.Mutex + cacheLock sync.Mutex } -func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport { - if cache == nil { - cache = NewInMemoryCache() - } +func NewServerResolvingTransport() *ServerResolvingTransport { srt := &ServerResolvingTransport{ + cache: make(map[string]*ResolvedServerName), resolveLocks: make(map[string]*sync.Mutex), - cache: cache, - Dialer: &net.Dialer{}, + + Dialer: &net.Dialer{}, } srt.Transport = &http.Transport{ DialContext: srt.DialContext, @@ -52,6 +50,12 @@ func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, a return srt.Dialer.DialContext(ctx, network, addrs[0]) } +type contextKey int + +const ( + contextKeyIPPort contextKey = iota +) + func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) { if request.URL.Scheme != "matrix-federation" { return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme) @@ -68,25 +72,37 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res } func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) { - srt.resolveLocksLock.Lock() - lock, ok := srt.resolveLocks[serverName] - if !ok { - lock = &sync.Mutex{} - srt.resolveLocks[serverName] = lock + res, lock := srt.getResolveCache(serverName) + if res != nil { + return res, nil } - srt.resolveLocksLock.Unlock() - lock.Lock() defer lock.Unlock() - res, err := srt.cache.LoadResolution(serverName) - if err != nil { - return nil, fmt.Errorf("failed to read cache: %w", err) - } else if res != nil { - return res, nil - } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil { - return nil, err - } else { - srt.cache.StoreResolution(res) + res, _ = srt.getResolveCache(serverName) + if res != nil { return res, nil } + var err error + res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts) + if err != nil { + return nil, err + } + srt.cacheLock.Lock() + srt.cache[serverName] = res + srt.cacheLock.Unlock() + return res, nil +} + +func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) { + srt.cacheLock.Lock() + defer srt.cacheLock.Unlock() + if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 { + return val, nil + } + rl, ok := srt.resolveLocks[serverName] + if !ok { + rl = &sync.Mutex{} + srt.resolveLocks[serverName] = rl + } + return nil, rl } diff --git a/federation/keyserver.go b/federation/keyserver.go index d32ba5cf..3e74bfdf 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,17 +8,13 @@ package federation import ( "encoding/json" + "fmt" "net/http" "strconv" "time" - "github.com/rs/zerolog" - "github.com/rs/zerolog/hlog" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exhttp" + "github.com/gorilla/mux" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" - "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" @@ -51,29 +47,34 @@ type KeyServer struct { KeyProvider ServerKeyProvider Version ServerVersion WellKnownTarget string - OtherKeys KeyCache } // Register registers the key server endpoints to the given router. -func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) { - r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown) - r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion) - keyRouter := http.NewServeMux() - keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey) - keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys) - keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys) - errorBodies := exhttp.ErrorBodies{ - NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), - MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), - } - r.Handle("/_matrix/key/", exhttp.ApplyMiddleware( - keyRouter, - exhttp.StripPrefix("/_matrix/key"), - hlog.NewHandler(log), - hlog.RequestIDHandler("request_id", "Request-Id"), - requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), - exhttp.HandleErrors(errorBodies), - )) +func (ks *KeyServer) Register(r *mux.Router) { + r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet) + r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet) + keyRouter := r.PathPrefix("/_matrix/key").Subrouter() + keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet) + keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) + keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) + keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Unrecognized endpoint", + }) + }) + keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Invalid method for endpoint", + }) + }) +} + +func jsonResponse(w http.ResponseWriter, code int, data any) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(data) } // RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint. @@ -86,9 +87,12 @@ type RespWellKnown struct { // https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) { if ks.WellKnownTarget == "" { - mautrix.MNotFound.WithMessage("No well-known target set").Write(w) + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "No well-known target set", + }) } else { - exhttp.WriteJSONResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) + jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) } } @@ -101,7 +105,7 @@ type RespServerVersion struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { - exhttp.WriteJSONResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) + jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) } // GetServerKey implements the `GET /_matrix/key/v2/server` endpoint. @@ -110,9 +114,12 @@ func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) { domain, key := ks.KeyProvider.Get(r) if key == nil { - mautrix.MNotFound.WithMessage("No signing key found for %q", r.Host).Write(w) + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: fmt.Sprintf("No signing key found for %q", r.Host), + }) } else { - exhttp.WriteJSONResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) + jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) } } @@ -137,7 +144,10 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { var req ReqQueryKeys err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - mautrix.MBadJSON.WithMessage("failed to parse request: %v", err).Write(w) + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + ErrCode: mautrix.MBadJSON.ErrCode, + Err: fmt.Sprintf("failed to parse request: %v", err), + }) return } @@ -155,7 +165,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { } } } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + jsonResponse(w, http.StatusOK, resp) } // GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint @@ -167,39 +177,27 @@ type GetQueryKeysResponse struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { - serverName := r.PathValue("serverName") + serverName := mux.Vars(r)["serverName"] minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) if err != nil && minimumValidUntilTSString != "" { - mautrix.MInvalidParam.WithMessage("failed to parse ?minimum_valid_until_ts: %v", err).Write(w) + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + ErrCode: mautrix.MInvalidParam.ErrCode, + Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err), + }) return } else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) { - mautrix.MInvalidParam.WithMessage("minimum_valid_until_ts may not be more than 24 hours in the future").Write(w) + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + ErrCode: mautrix.MInvalidParam.ErrCode, + Err: "minimum_valid_until_ts may not be more than 24 hours in the future", + }) return } resp := &GetQueryKeysResponse{ ServerKeys: []*ServerKeyResponse{}, } - domain, key := ks.KeyProvider.Get(r) - if domain == serverName { - if key != nil { - resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) - } - } else if ks.OtherKeys != nil { - otherKey, err := ks.OtherKeys.LoadKeys(serverName) - if err != nil { - mautrix.MUnknown.WithMessage("Failed to load keys from cache").Write(w) - return - } - if key != nil && domain != "" { - signature, err := key.SignJSON(otherKey) - if err == nil { - otherKey.Signatures[domain] = map[id.KeyID]string{ - key.ID: signature, - } - } - } - resp.ServerKeys = append(resp.ServerKeys, otherKey) + if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName { + resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + jsonResponse(w, http.StatusOK, resp) } diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go deleted file mode 100644 index 16706fe5..00000000 --- a/federation/pdu/auth.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "slices" - - "github.com/tidwall/gjson" - "go.mau.fi/util/exgjson" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type StateKey struct { - Type string - StateKey string -} - -var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token") - -type AuthEventSelection []StateKey - -func (aes *AuthEventSelection) Add(evtType, stateKey string) { - key := StateKey{Type: evtType, StateKey: stateKey} - if !aes.Has(key) { - *aes = append(*aes, key) - } -} - -func (aes *AuthEventSelection) Has(key StateKey) bool { - return slices.Contains(*aes, key) -} - -func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) { - if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { - return AuthEventSelection{} - } - keys = make(AuthEventSelection, 0, 3) - if !roomVersion.RoomIDIsCreateEventID() { - keys.Add(event.StateCreate.Type, "") - } - keys.Add(event.StatePowerLevels.Type, "") - keys.Add(event.StateMember.Type, pdu.Sender.String()) - if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { - keys.Add(event.StateMember.Type, *pdu.StateKey) - membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) - if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { - keys.Add(event.StateJoinRules.Type, "") - } - if membership == event.MembershipInvite { - thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str - if thirdPartyInviteToken != "" { - keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) - } - } - if membership == event.MembershipJoin && roomVersion.RestrictedJoins() { - authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str - if authorizedVia != "" { - keys.Add(event.StateMember.Type, authorizedVia) - } - } - } - return -} diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go deleted file mode 100644 index 38ef83e9..00000000 --- a/federation/pdu/hash.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "fmt" - - "github.com/tidwall/gjson" - - "maunium.net/go/mautrix/id" -) - -func (pdu *PDU) CalculateContentHash() ([32]byte, error) { - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - pduClone := pdu.Clone() - pduClone.Signatures = nil - pduClone.Unsigned = nil - pduClone.Hashes = nil - rawJSON, err := marshalCanonical(pduClone) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *PDU) FillContentHash() error { - if pdu == nil { - return ErrPDUIsNil - } else if pdu.Hashes != nil { - return nil - } else if hash, err := pdu.CalculateContentHash(); err != nil { - return err - } else { - pdu.Hashes = &Hashes{SHA256: hash[:]} - return nil - } -} - -func (pdu *PDU) VerifyContentHash() bool { - if pdu == nil || pdu.Hashes == nil { - return false - } - calculatedHash, err := pdu.CalculateContentHash() - if err != nil { - return false - } - return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) -} - -func (pdu *PDU) GetRoomID() (id.RoomID, error) { - if pdu == nil { - return "", ErrPDUIsNil - } else if pdu.Type != "m.room.create" { - return "", fmt.Errorf("room ID can only be calculated for m.room.create events") - } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() { - return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion) - } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil { - return "", fmt.Errorf("failed to calculate event ID: %w", err) - } else { - return id.RoomID(evtID), nil - } -} - -var UseInternalMetaForGetEventID = false - -func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { - if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" { - return pdu.InternalMeta.EventID, nil - } - return pdu.calculateEventID(roomVersion, '$') -} - -func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { - if err := pdu.FillContentHash(); err != nil { - return [32]byte{}, err - } - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) { - referenceHash, err := pdu.GetReferenceHash(roomVersion) - if err != nil { - return "", err - } - eventID := make([]byte, 44) - eventID[0] = prefix - switch roomVersion.EventIDFormat() { - case id.EventIDFormatCustom: - return "", fmt.Errorf("*pdu.PDU can only be used for room v3+") - case id.EventIDFormatBase64: - base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:]) - case id.EventIDFormatURLSafeBase64: - base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:]) - default: - return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat()) - } - return id.EventID(eventID), nil -} diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go deleted file mode 100644 index 17417e12..00000000 --- a/federation/pdu/hash_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "encoding/base64" - "testing" - - "github.com/stretchr/testify/assert" - "go.mau.fi/util/exerrors" -) - -func TestPDU_CalculateContentHash(t *testing.T) { - for _, test := range testPDUs { - if test.redacted { - continue - } - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - contentHash := exerrors.Must(parsed.CalculateContentHash()) - assert.Equal( - t, - base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), - base64.RawStdEncoding.EncodeToString(contentHash[:]), - ) - }) - } -} - -func TestPDU_VerifyContentHash(t *testing.T) { - for _, test := range testPDUs { - if test.redacted { - continue - } - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - assert.True(t, parsed.VerifyContentHash()) - }) - } -} - -func TestPDU_GetEventID(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion)) - assert.Equal(t, test.eventID, gotEventID) - }) - } -} diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go deleted file mode 100644 index 17db6995..00000000 --- a/federation/pdu/pdu.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "bytes" - "crypto/ed25519" - "encoding/json/jsontext" - "encoding/json/v2" - "errors" - "fmt" - "strings" - "time" - - "github.com/tidwall/gjson" - "go.mau.fi/util/jsonbytes" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU. -// -// The input time is the timestamp of the event. The function should attempt to fetch a key that is -// valid at or after this time, but if that is not possible, the latest available key should be -// returned without an error. The verify function will do its own validity checking based on the -// returned valid until timestamp. -type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) - -type AnyPDU interface { - GetRoomID() (id.RoomID, error) - GetEventID(roomVersion id.RoomVersion) (id.EventID, error) - GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) - CalculateContentHash() ([32]byte, error) - FillContentHash() error - VerifyContentHash() bool - Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error - VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error - ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) - AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) -} - -var ( - _ AnyPDU = (*PDU)(nil) - _ AnyPDU = (*RoomV1PDU)(nil) -) - -type InternalMeta struct { - EventID id.EventID `json:"event_id,omitempty"` - Rejected bool `json:"rejected,omitempty"` - Extra map[string]any `json:",unknown"` -} - -type PDU struct { - AuthEvents []id.EventID `json:"auth_events"` - Content jsontext.Value `json:"content"` - Depth int64 `json:"depth"` - Hashes *Hashes `json:"hashes,omitzero"` - OriginServerTS int64 `json:"origin_server_ts"` - PrevEvents []id.EventID `json:"prev_events"` - Redacts *id.EventID `json:"redacts,omitzero"` - RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events - Sender id.UserID `json:"sender"` - Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` - StateKey *string `json:"state_key,omitzero"` - Type string `json:"type"` - Unsigned jsontext.Value `json:"unsigned,omitzero"` - InternalMeta InternalMeta `json:"-"` - - Unknown jsontext.Value `json:",unknown"` - - // Deprecated legacy fields - DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` - DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` - DeprecatedMembership jsontext.Value `json:"membership,omitzero"` -} - -var ErrPDUIsNil = errors.New("PDU is nil") - -type Hashes struct { - SHA256 jsonbytes.UnpaddedBytes `json:"sha256"` - - Unknown jsontext.Value `json:",unknown"` -} - -func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { - if pdu.Type == "m.room.create" && roomVersion == "" { - roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str) - } - evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} - if pdu.StateKey != nil { - evtType.Class = event.StateEventType - } - eventID, err := pdu.GetEventID(roomVersion) - if err != nil { - return nil, err - } - roomID := pdu.RoomID - if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() { - roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1)) - } - evt := &event.Event{ - StateKey: pdu.StateKey, - Sender: pdu.Sender, - Type: evtType, - Timestamp: pdu.OriginServerTS, - ID: eventID, - RoomID: roomID, - Redacts: ptr.Val(pdu.Redacts), - } - err = json.Unmarshal(pdu.Content, &evt.Content) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal content: %w", err) - } - return evt, nil -} - -func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) { - if signature == "" { - return - } - if pdu.Signatures == nil { - pdu.Signatures = make(map[string]map[id.KeyID]string) - } - if _, ok := pdu.Signatures[serverName]; !ok { - pdu.Signatures[serverName] = make(map[id.KeyID]string) - } - pdu.Signatures[serverName][keyID] = signature -} - -func marshalCanonical(data any) (jsontext.Value, error) { - marshaledBytes, err := json.Marshal(data) - if err != nil { - return nil, err - } - marshaled := jsontext.Value(marshaledBytes) - err = marshaled.Canonicalize() - if err != nil { - return nil, err - } - check := canonicaljson.CanonicalJSONAssumeValid(marshaled) - if !bytes.Equal(marshaled, check) { - fmt.Println(string(marshaled)) - fmt.Println(string(check)) - return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled)) - } - return marshaled, nil -} diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go deleted file mode 100644 index 59d7c3a6..00000000 --- a/federation/pdu/pdu_test.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "encoding/json/v2" - "time" - - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -type serverKey struct { - key id.SigningKey - validUntilTS time.Time -} - -type serverDetails struct { - serverName string - keys map[id.KeyID]serverKey -} - -func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { - if serverName != sd.serverName { - return "", time.Time{}, nil - } - key, ok := sd.keys[keyID] - if ok { - return key.key, key.validUntilTS, nil - } - return "", time.Time{}, nil -} - -var mauniumNet = serverDetails{ - serverName: "maunium.net", - keys: map[id.KeyID]serverKey{ - "ed25519:a_xxeS": { - key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg", - validUntilTS: time.Now(), - }, - }, -} -var envsNet = serverDetails{ - serverName: "envs.net", - keys: map[id.KeyID]serverKey{ - "ed25519:a_zIqy": { - key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0", - validUntilTS: time.UnixMilli(1722360538068), - }, - "ed25519:wuJyKT": { - key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0", - validUntilTS: time.Now(), - }, - }, -} -var matrixOrg = serverDetails{ - serverName: "matrix.org", - keys: map[id.KeyID]serverKey{ - "ed25519:auto": { - key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw", - validUntilTS: time.UnixMilli(1576767829750), - }, - "ed25519:a_RXGa": { - key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ", - validUntilTS: time.Now(), - }, - }, -} -var continuwuityOrg = serverDetails{ - serverName: "continuwuity.org", - keys: map[id.KeyID]serverKey{ - "ed25519:PwHlNsFu": { - key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI", - validUntilTS: time.Now(), - }, - }, -} -var novaAstraltechOrg = serverDetails{ - serverName: "nova.astraltech.org", - keys: map[id.KeyID]serverKey{ - "ed25519:a_afpo": { - key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o", - validUntilTS: time.Now(), - }, - }, -} - -type testPDU struct { - name string - pdu string - eventID id.EventID - roomVersion id.RoomVersion - redacted bool - serverDetails -} - -var roomV4MessageTestPDU = testPDU{ - name: "m.room.message in v4 room", - pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`, - eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA", - roomVersion: id.RoomV4, - serverDetails: mauniumNet, -} - -var roomV12MessageTestPDU = testPDU{ - name: "m.room.message in v12 room", - pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`, - eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc", - roomVersion: id.RoomV12, - serverDetails: mauniumNet, -} - -var testPDUs = []testPDU{roomV4MessageTestPDU, { - name: "m.room.message in v5 room", - pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`, - eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA", - roomVersion: id.RoomV5, - serverDetails: mauniumNet, -}, { - name: "m.room.message in v10 room", - pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`, - eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY", - roomVersion: id.RoomV10, - serverDetails: mauniumNet, -}, { - name: "m.room.message in v11 room", - pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`, - eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`, - roomVersion: id.RoomV11, - serverDetails: mauniumNet, -}, roomV12MessageTestPDU, { - name: "m.room.create in v4 room", - pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`, - eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY", - roomVersion: id.RoomV4, - serverDetails: matrixOrg, -}, { - name: "m.room.create in v10 room", - pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`, - eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ", - roomVersion: id.RoomV10, - serverDetails: envsNet, -}, { - name: "m.room.create in v12 room", - pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`, - eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - roomVersion: id.RoomV12, - serverDetails: mauniumNet, -}, { - name: "m.room.member in v4 room", - pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`, - eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo", - roomVersion: id.RoomV4, - serverDetails: mauniumNet, -}, { - name: "m.room.member in v10 room", - pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`, - eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs", - roomVersion: id.RoomV10, - serverDetails: mauniumNet, -}, { - name: "m.room.member of creator in v12 room", - pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`, - eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg", - roomVersion: id.RoomV12, - serverDetails: mauniumNet, -}, { - name: "custom message event in v4 room", - pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`, - eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0", - roomVersion: id.RoomV4, - serverDetails: mauniumNet, -}, { - name: "redacted m.room.member event in v11 room with 2 signatures", - pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`, - eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ", - roomVersion: id.RoomV11, - serverDetails: novaAstraltechOrg, - redacted: true, -}} - -func parsePDU(pdu string) (out *pdu.PDU) { - exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) - return -} diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go deleted file mode 100644 index d7ee0c15..00000000 --- a/federation/pdu/redact.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "encoding/json/jsontext" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.mau.fi/util/exgjson" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/id" -) - -func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value { - filtered := jsontext.Value("{}") - var err error - for _, path := range allowedPaths { - res := gjson.GetBytes(object, path) - if res.Exists() { - var raw jsontext.Value - if res.Index > 0 { - raw = object[res.Index : res.Index+len(res.Raw)] - } else { - raw = jsontext.Value(res.Raw) - } - filtered, err = sjson.SetRawBytes(filtered, path, raw) - if err != nil { - panic(err) - } - } - } - return filtered -} - -func (pdu *PDU) Clone() *PDU { - return ptr.Clone(pdu) -} - -func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU { - pdu.Signatures = nil - return pdu.Redact(roomVersion) -} - -var emptyObject = jsontext.Value("{}") - -func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value { - switch eventType { - case "m.room.member": - allowedPaths := []string{"membership"} - if roomVersion.RestrictedJoinsFix() { - allowedPaths = append(allowedPaths, "join_authorised_via_users_server") - } - if roomVersion.UpdatedRedactionRules() { - allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed")) - } - return filteredObject(content, allowedPaths...) - case "m.room.create": - if !roomVersion.UpdatedRedactionRules() { - return filteredObject(content, "creator") - } - return content - case "m.room.join_rules": - if roomVersion.RestrictedJoins() { - return filteredObject(content, "join_rule", "allow") - } - return filteredObject(content, "join_rule") - case "m.room.power_levels": - allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"} - if roomVersion.UpdatedRedactionRules() { - allowedKeys = append(allowedKeys, "invite") - } - return filteredObject(content, allowedKeys...) - case "m.room.history_visibility": - return filteredObject(content, "history_visibility") - case "m.room.redaction": - if roomVersion.RedactsInContent() { - return filteredObject(content, "redacts") - } - return emptyObject - case "m.room.aliases": - if roomVersion.SpecialCasedAliasesAuth() { - return filteredObject(content, "aliases") - } - return emptyObject - default: - return emptyObject - } -} - -func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { - pdu.Unknown = nil - pdu.Unsigned = nil - if roomVersion.UpdatedRedactionRules() { - pdu.DeprecatedPrevState = nil - pdu.DeprecatedOrigin = nil - pdu.DeprecatedMembership = nil - } - if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() { - pdu.Redacts = nil - } - pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) - return pdu -} diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go deleted file mode 100644 index 04e7c5ef..00000000 --- a/federation/pdu/signature.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "crypto/ed25519" - "encoding/base64" - "fmt" - "time" - - "maunium.net/go/mautrix/federation/signutil" - "maunium.net/go/mautrix/id" -) - -func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { - err := pdu.FillContentHash() - if err != nil { - return err - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) - } - signature := ed25519.Sign(privateKey, rawJSON) - pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature)) - return nil -} - -func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) - } - verified := false - for keyID, sig := range pdu.Signatures[serverName] { - originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, validUntil, err := getKey(serverName, keyID, originServerTS) - if err != nil { - return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) - } else if key == "" { - return fmt.Errorf("key %s not found for %s", keyID, serverName) - } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() { - return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS) - } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { - return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) - } else { - verified = true - } - } - if !verified { - return fmt.Errorf("no verifiable signatures found for server %s", serverName) - } - return nil -} diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go deleted file mode 100644 index 01df5076..00000000 --- a/federation/pdu/signature_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "crypto/ed25519" - "encoding/base64" - "encoding/json/jsontext" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -func TestPDU_VerifySignature(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) - assert.NoError(t, err) - }) - } -} - -func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { - test := roomV12MessageTestPDU - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - return - }) - assert.Error(t, err) -} - -func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { - test := roomV4MessageTestPDU - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - key = test.keys[keyID].key - validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - return - }) - assert.NoError(t, err) -} - -func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) { - test := roomV12MessageTestPDU - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - key = test.keys[keyID].key - validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - return - }) - assert.Error(t, err) -} - -func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) { - test := roomV12MessageTestPDU - parsed := parsePDU(test.pdu) - for _, sigs := range parsed.Signatures { - for key := range sigs { - sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC" - } - } - err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) - assert.Error(t, err) -} - -func TestPDU_Sign(t *testing.T) { - pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil)) - evt := &pdu.PDU{ - AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"}, - Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`), - Depth: 123, - OriginServerTS: 1755384351627, - PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"}, - RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - Sender: "@tulir:example.com", - Type: "m.room.message", - } - err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey) - require.NoError(t, err) - err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - if serverName == "example.com" && keyID == "ed25519:rand" { - key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey)) - validUntil = time.Now() - } - return - }) - require.NoError(t, err) - -} diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go deleted file mode 100644 index 9557f8ab..00000000 --- a/federation/pdu/v1.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "crypto/ed25519" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json/jsontext" - "encoding/json/v2" - "fmt" - "time" - - "github.com/tidwall/gjson" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/federation/signutil" - "maunium.net/go/mautrix/id" -) - -type V1EventReference struct { - ID id.EventID - Hashes Hashes -} - -var ( - _ json.UnmarshalerFrom = (*V1EventReference)(nil) - _ json.MarshalerTo = (*V1EventReference)(nil) -) - -func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error { - return json.MarshalEncode(enc, []any{er.ID, er.Hashes}) -} - -func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error { - var ref V1EventReference - var data []jsontext.Value - if err := json.UnmarshalDecode(dec, &data); err != nil { - return err - } else if len(data) != 2 { - return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data)) - } else if err = json.Unmarshal(data[0], &ref.ID); err != nil { - return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err) - } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil { - return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err) - } - *er = ref - return nil -} - -type RoomV1PDU struct { - AuthEvents []V1EventReference `json:"auth_events"` - Content jsontext.Value `json:"content"` - Depth int64 `json:"depth"` - EventID id.EventID `json:"event_id"` - Hashes *Hashes `json:"hashes,omitzero"` - OriginServerTS int64 `json:"origin_server_ts"` - PrevEvents []V1EventReference `json:"prev_events"` - Redacts *id.EventID `json:"redacts,omitzero"` - RoomID id.RoomID `json:"room_id"` - Sender id.UserID `json:"sender"` - Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` - StateKey *string `json:"state_key,omitzero"` - Type string `json:"type"` - Unsigned jsontext.Value `json:"unsigned,omitzero"` - - Unknown jsontext.Value `json:",unknown"` - - // Deprecated legacy fields - DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` - DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` - DeprecatedMembership jsontext.Value `json:"membership,omitzero"` -} - -func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) { - return pdu.RoomID, nil -} - -func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { - if !pdu.SupportsRoomVersion(roomVersion) { - return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion) - } - return pdu.EventID, nil -} - -func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU { - pdu.Signatures = nil - return pdu.Redact(roomVersion) -} - -func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU { - pdu.Unknown = nil - pdu.Unsigned = nil - if pdu.Type != "m.room.redaction" { - pdu.Redacts = nil - } - pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) - return pdu -} - -func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { - if !pdu.SupportsRoomVersion(roomVersion) { - return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion) - } - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { - if err := pdu.FillContentHash(); err != nil { - return [32]byte{}, err - } - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) { - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - pduClone := pdu.Clone() - pduClone.Signatures = nil - pduClone.Unsigned = nil - pduClone.Hashes = nil - rawJSON, err := marshalCanonical(pduClone) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *RoomV1PDU) FillContentHash() error { - if pdu == nil { - return ErrPDUIsNil - } else if pdu.Hashes != nil { - return nil - } else if hash, err := pdu.CalculateContentHash(); err != nil { - return err - } else { - pdu.Hashes = &Hashes{SHA256: hash[:]} - return nil - } -} - -func (pdu *RoomV1PDU) VerifyContentHash() bool { - if pdu == nil || pdu.Hashes == nil { - return false - } - calculatedHash, err := pdu.CalculateContentHash() - if err != nil { - return false - } - return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) -} - -func (pdu *RoomV1PDU) Clone() *RoomV1PDU { - return ptr.Clone(pdu) -} - -func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { - if !pdu.SupportsRoomVersion(roomVersion) { - return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion) - } - err := pdu.FillContentHash() - if err != nil { - return err - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) - } - signature := ed25519.Sign(privateKey, rawJSON) - if pdu.Signatures == nil { - pdu.Signatures = make(map[string]map[id.KeyID]string) - } - if _, ok := pdu.Signatures[serverName]; !ok { - pdu.Signatures[serverName] = make(map[id.KeyID]string) - } - pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) - return nil -} - -func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { - if !pdu.SupportsRoomVersion(roomVersion) { - return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion) - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) - } - verified := false - for keyID, sig := range pdu.Signatures[serverName] { - originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, _, err := getKey(serverName, keyID, originServerTS) - if err != nil { - return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) - } else if key == "" { - return fmt.Errorf("key %s not found for %s", keyID, serverName) - } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { - return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) - } else { - verified = true - } - } - if !verified { - return fmt.Errorf("no verifiable signatures found for server %s", serverName) - } - return nil -} - -func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool { - switch roomVersion { - case id.RoomV0, id.RoomV1, id.RoomV2: - return true - default: - return false - } -} - -func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { - if !pdu.SupportsRoomVersion(roomVersion) { - return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion) - } - evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} - if pdu.StateKey != nil { - evtType.Class = event.StateEventType - } - evt := &event.Event{ - StateKey: pdu.StateKey, - Sender: pdu.Sender, - Type: evtType, - Timestamp: pdu.OriginServerTS, - ID: pdu.EventID, - RoomID: pdu.RoomID, - Redacts: ptr.Val(pdu.Redacts), - } - err := json.Unmarshal(pdu.Content, &evt.Content) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal content: %w", err) - } - return evt, nil -} - -func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) { - if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { - return AuthEventSelection{} - } - keys = make(AuthEventSelection, 0, 3) - keys.Add(event.StateCreate.Type, "") - keys.Add(event.StatePowerLevels.Type, "") - keys.Add(event.StateMember.Type, pdu.Sender.String()) - if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { - keys.Add(event.StateMember.Type, *pdu.StateKey) - membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) - if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { - keys.Add(event.StateJoinRules.Type, "") - } - if membership == event.MembershipInvite { - thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str - if thirdPartyInviteToken != "" { - keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) - } - } - } - return -} diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go deleted file mode 100644 index ecf2dbd2..00000000 --- a/federation/pdu/v1_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "encoding/base64" - "encoding/json/v2" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -var testV1PDUs = []testPDU{{ - name: "m.room.message in v1 room", - pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`, - eventID: "$16738169022163bokdi:maunium.net", - roomVersion: id.RoomV1, - serverDetails: mauniumNet, -}, { - name: "m.room.create in v1 room", - pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`, - eventID: "$143454825711DhCxH:matrix.org", - roomVersion: id.RoomV1, - serverDetails: matrixOrg, -}, { - name: "m.room.member in v1 room", - pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`, - eventID: "$15660585693272iEryv:maunium.net", - roomVersion: id.RoomV1, - serverDetails: mauniumNet, -}} - -func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) { - exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) - return -} - -func TestRoomV1PDU_CalculateContentHash(t *testing.T) { - for _, test := range testV1PDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parseV1PDU(test.pdu) - contentHash := exerrors.Must(parsed.CalculateContentHash()) - assert.Equal( - t, - base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), - base64.RawStdEncoding.EncodeToString(contentHash[:]), - ) - }) - } -} - -func TestRoomV1PDU_VerifyContentHash(t *testing.T) { - for _, test := range testV1PDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parseV1PDU(test.pdu) - assert.True(t, parsed.VerifyContentHash()) - }) - } -} - -func TestRoomV1PDU_VerifySignature(t *testing.T) { - for _, test := range testV1PDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parseV1PDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { - key, ok := test.keys[keyID] - if ok { - return key.key, key.validUntilTS, nil - } - return "", time.Time{}, nil - }) - assert.NoError(t, err) - }) - } -} diff --git a/federation/resolution.go b/federation/resolution.go index a3188266..e6785988 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -11,7 +11,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net" "net/http" "net/url" @@ -20,8 +19,6 @@ import ( "time" "github.com/rs/zerolog" - - "maunium.net/go/mautrix" ) type ResolvedServerName struct { @@ -80,10 +77,7 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS } else if wellKnown != nil { output.Expires = expiry output.HostHeader = wellKnown.Server - wkHost, wkPort, ok := ParseServerName(wellKnown.Server) - if ok { - hostname, port = wkHost, wkPort - } + hostname, port, ok = ParseServerName(wellKnown.Server) // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known if net.ParseIP(hostname) != nil || port != 0 { if port == 0 { @@ -125,38 +119,6 @@ func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net return target, err } -func parseCacheControl(resp *http.Response) time.Duration { - cc := resp.Header.Get("Cache-Control") - if cc == "" { - return 0 - } - parts := strings.Split(cc, ",") - for _, part := range parts { - kv := strings.SplitN(strings.TrimSpace(part), "=", 1) - switch kv[0] { - case "no-cache", "no-store": - return 0 - case "max-age": - if len(kv) < 2 { - continue - } - maxAge, err := strconv.Atoi(kv[1]) - if err != nil || maxAge < 0 { - continue - } - age, _ := strconv.Atoi(resp.Header.Get("Age")) - return time.Duration(maxAge-age) * time.Second - } - } - return 0 -} - -const ( - MinCacheDuration = 1 * time.Hour - MaxCacheDuration = 72 * time.Hour - DefaultCacheDuration = 24 * time.Hour -) - // RequestWellKnown sends a request to the well-known endpoint of a server and returns the response, // plus the time when the cache should expire. func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) { @@ -176,23 +138,14 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) - } else if resp.ContentLength > mautrix.WellKnownMaxSize { - return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength) } var respData RespWellKnown - err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData) + err = json.NewDecoder(resp.Body).Decode(&respData) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) } else if respData.Server == "" { return nil, time.Time{}, errors.New("server name not found in response") } - cacheDuration := parseCacheControl(resp) - if cacheDuration <= 0 { - cacheDuration = DefaultCacheDuration - } else if cacheDuration < MinCacheDuration { - cacheDuration = MinCacheDuration - } else if cacheDuration > MaxCacheDuration { - cacheDuration = MaxCacheDuration - } + // TODO parse cache-control header return &respData, time.Now().Add(24 * time.Hour), nil } diff --git a/federation/serverauth.go b/federation/serverauth.go deleted file mode 100644 index cd300341..00000000 --- a/federation/serverauth.go +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package federation - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "maps" - "net/http" - "slices" - "strings" - "sync" - - "github.com/rs/zerolog" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/id" -) - -type ServerAuth struct { - Keys KeyCache - Client *Client - GetDestination func(XMatrixAuth) string - MaxBodySize int64 - - keyFetchLocks map[string]*sync.Mutex - keyFetchLocksLock sync.Mutex -} - -func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth XMatrixAuth) string) *ServerAuth { - return &ServerAuth{ - Keys: keyCache, - Client: client, - GetDestination: getDestination, - MaxBodySize: 50 * 1024 * 1024, - keyFetchLocks: make(map[string]*sync.Mutex), - } -} - -var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized} - -var ( - errMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header") - errInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix") - errMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components") - errInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header") - errFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys") - errInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures") - errRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large") - errInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON") - errBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body") - errInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature") -) - -type XMatrixAuth struct { - Origin string - Destination string - KeyID id.KeyID - Signature string -} - -func (xma XMatrixAuth) String() string { - return fmt.Sprintf( - `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, - xma.Origin, - xma.Destination, - xma.KeyID, - xma.Signature, - ) -} - -func ParseXMatrixAuth(auth string) (xma XMatrixAuth) { - auth = strings.TrimPrefix(auth, "X-Matrix ") - // TODO upgrade to strings.SplitSeq after Go 1.24 is the minimum - for _, part := range strings.Split(auth, ",") { - part = strings.TrimSpace(part) - eqIdx := strings.Index(part, "=") - if eqIdx == -1 || strings.Count(part, "=") > 1 { - continue - } - val := strings.Trim(part[eqIdx+1:], "\"") - switch strings.ToLower(part[:eqIdx]) { - case "origin": - xma.Origin = val - case "destination": - xma.Destination = val - case "key": - xma.KeyID = id.KeyID(val) - case "sig": - xma.Signature = val - } - } - return -} - -func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string, keyID id.KeyID) (*ServerKeyResponse, error) { - res, err := sa.Keys.LoadKeys(serverName) - if err != nil { - return nil, fmt.Errorf("failed to read cache: %w", err) - } else if res.HasKey(keyID) { - return res, nil - } - - sa.keyFetchLocksLock.Lock() - lock, ok := sa.keyFetchLocks[serverName] - if !ok { - lock = &sync.Mutex{} - sa.keyFetchLocks[serverName] = lock - } - sa.keyFetchLocksLock.Unlock() - - lock.Lock() - defer lock.Unlock() - res, err = sa.Keys.LoadKeys(serverName) - if err != nil { - return nil, fmt.Errorf("failed to read cache: %w", err) - } else if res != nil { - if res.HasKey(keyID) { - return res, nil - } else if !sa.Keys.ShouldReQuery(serverName) { - zerolog.Ctx(ctx).Trace(). - Str("server_name", serverName). - Stringer("key_id", keyID). - Msg("Not sending key request for missing key ID, last query was too recent") - return res, nil - } - } - res, err = sa.Client.ServerKeys(ctx, serverName) - if err != nil { - sa.Keys.StoreFetchError(serverName, err) - return nil, err - } - sa.Keys.StoreKeys(res) - return res, nil -} - -type fixedLimitedReader struct { - R io.Reader - N int64 - Err error -} - -func (l *fixedLimitedReader) Read(p []byte) (n int, err error) { - if l.N <= 0 { - return 0, l.Err - } - if int64(len(p)) > l.N { - p = p[0:l.N] - } - n, err = l.R.Read(p) - l.N -= int64(n) - return -} - -func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.RespError) { - defer func() { - _ = r.Body.Close() - }() - log := zerolog.Ctx(r.Context()) - if r.ContentLength > sa.MaxBodySize { - return nil, &errRequestBodyTooLarge - } - auth := r.Header.Get("Authorization") - if auth == "" { - return nil, &errMissingAuthHeader - } else if !strings.HasPrefix(auth, "X-Matrix ") { - return nil, &errInvalidAuthHeader - } - parsed := ParseXMatrixAuth(auth) - if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" { - log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header") - return nil, &errMalformedAuthHeader - } - destination := sa.GetDestination(parsed) - if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) { - log.Trace(). - Str("got_destination", parsed.Destination). - Str("expected_destination", destination). - Msg("Invalid destination in X-Matrix header") - return nil, &errInvalidDestination - } - resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID) - if err != nil { - if !errors.Is(err, ErrRecentKeyQueryFailed) { - log.Err(err). - Str("server_name", parsed.Origin). - Msg("Failed to query keys to authenticate request") - } else { - log.Trace().Err(err). - Str("server_name", parsed.Origin). - Msg("Failed to query keys to authenticate request (cached error)") - } - return nil, &errFailedToQueryKeys - } else if err := resp.VerifySelfSignature(); err != nil { - log.Trace().Err(err). - Str("server_name", parsed.Origin). - Msg("Failed to validate self-signatures of server keys") - return nil, &errInvalidSelfSignatures - } - key, ok := resp.VerifyKeys[parsed.KeyID] - if !ok { - keys := slices.Collect(maps.Keys(resp.VerifyKeys)) - log.Trace(). - Stringer("expected_key_id", parsed.KeyID). - Any("found_key_ids", keys). - Msg("Didn't find expected key ID to verify request") - return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys)) - } - var reqBody []byte - if r.ContentLength != 0 && r.Method != http.MethodGet && r.Method != http.MethodHead { - reqBody, err = io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge}) - if errors.Is(err, errRequestBodyTooLarge) { - return nil, &errRequestBodyTooLarge - } else if err != nil { - log.Err(err). - Str("server_name", parsed.Origin). - Msg("Failed to read request body to authenticate") - return nil, &errBodyReadFailed - } else if !json.Valid(reqBody) { - return nil, &errInvalidJSONBody - } - } - err = (&signableRequest{ - Method: r.Method, - URI: r.URL.RequestURI(), - Origin: parsed.Origin, - Destination: destination, - Content: reqBody, - }).Verify(key.Key, parsed.Signature) - if err != nil { - log.Trace().Err(err).Msg("Request has invalid signature") - return nil, &errInvalidRequestSignature - } - ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) - ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin) - ctx = log.With(). - Str("origin_server_name", parsed.Origin). - Str("destination_server_name", destination). - Logger().WithContext(ctx) - modifiedReq := r.WithContext(ctx) - if reqBody != nil { - modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) - } - return modifiedReq, nil -} - -func (sa *ServerAuth) AuthenticateMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if modifiedReq, err := sa.Authenticate(r); err != nil { - err.Write(w) - } else { - next.ServeHTTP(w, modifiedReq) - } - }) -} diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go deleted file mode 100644 index f99fc6cf..00000000 --- a/federation/serverauth_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package federation_test - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/federation" -) - -func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { - cli := federation.NewClient("", nil, nil) - ctx := context.Background() - for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} { - t.Run(name, func(t *testing.T) { - resp, err := cli.ServerKeys(ctx, name) - require.NoError(t, err) - assert.NoError(t, resp.VerifySelfSignature()) - }) - } -} diff --git a/federation/signingkey.go b/federation/signingkey.go index a4ad9679..67751b48 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -14,11 +14,9 @@ import ( "strings" "time" - "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -33,8 +31,8 @@ type SigningKey struct { // // The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function. func (sk *SigningKey) SynapseString() string { - alg, keyID := sk.ID.Parse() - return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) + alg, id := sk.ID.Parse() + return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) } // ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey. @@ -79,37 +77,6 @@ type ServerKeyResponse struct { OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"` Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"` - - Raw json.RawMessage `json:"-"` -} - -type QueryKeysResponse struct { - ServerKeys []*ServerKeyResponse `json:"server_keys"` -} - -func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { - if skr == nil { - return false - } else if _, ok := skr.VerifyKeys[keyID]; ok { - return true - } - return false -} - -func (skr *ServerKeyResponse) VerifySelfSignature() error { - for keyID, key := range skr.VerifyKeys { - if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil { - return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err) - } - } - return nil -} - -type marshalableSKR ServerKeyResponse - -func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error { - skr.Raw = data - return json.Unmarshal(data, (*marshalableSKR)(skr)) } type ServerVerifyKey struct { @@ -125,16 +92,12 @@ type OldVerifyKey struct { ExpiredTS jsontime.UnixMilli `json:"expired_ts"` } -func (sk *SigningKey) SignJSON(data any) (string, error) { +func (sk *SigningKey) SignJSON(data any) ([]byte, error) { marshaled, err := json.Marshal(data) if err != nil { - return "", err + return nil, err } - marshaled, err = sjson.DeleteBytes(marshaled, "signatures") - if err != nil { - return "", err - } - return base64.RawStdEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil + return sk.SignRawJSON(marshaled), nil } func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte { @@ -157,7 +120,7 @@ func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[i } skr.Signatures = map[string]map[id.KeyID]string{ serverName: { - sk.ID: signature, + sk.ID: base64.RawURLEncoding.EncodeToString(signature), }, } return skr diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go deleted file mode 100644 index ea0e7886..00000000 --- a/federation/signutil/verify.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package signutil - -import ( - "crypto/ed25519" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.mau.fi/util/exgjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/id" -) - -var ErrSignatureNotFound = errors.New("signature not found") -var ErrInvalidSignature = errors.New("invalid signature") - -func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error { - var err error - message, ok := data.(json.RawMessage) - if !ok { - message, err = json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal data: %w", err) - } - } - sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID))) - if sigVal.Type != gjson.String { - return ErrSignatureNotFound - } - message, err = sjson.DeleteBytes(message, "signatures") - if err != nil { - return fmt.Errorf("failed to delete signatures: %w", err) - } - message, err = sjson.DeleteBytes(message, "unsigned") - if err != nil { - return fmt.Errorf("failed to delete unsigned: %w", err) - } - return VerifyJSONRaw(key, sigVal.Str, message) -} - -func VerifyJSONAny(key id.SigningKey, data any) error { - var err error - message, ok := data.(json.RawMessage) - if !ok { - message, err = json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal data: %w", err) - } - } - sigs := gjson.GetBytes(message, "signatures") - if !sigs.IsObject() { - return ErrSignatureNotFound - } - message, err = sjson.DeleteBytes(message, "signatures") - if err != nil { - return fmt.Errorf("failed to delete signatures: %w", err) - } - message, err = sjson.DeleteBytes(message, "unsigned") - if err != nil { - return fmt.Errorf("failed to delete unsigned: %w", err) - } - var validated bool - sigs.ForEach(func(_, value gjson.Result) bool { - if !value.IsObject() { - return true - } - value.ForEach(func(_, value gjson.Result) bool { - if value.Type != gjson.String { - return true - } - validated = VerifyJSONRaw(key, value.Str, message) == nil - return !validated - }) - return !validated - }) - if !validated { - return ErrInvalidSignature - } - return nil -} - -func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { - sigBytes, err := base64.RawStdEncoding.DecodeString(sig) - if err != nil { - return fmt.Errorf("failed to decode signature: %w", err) - } - keyBytes, err := base64.RawStdEncoding.DecodeString(string(key)) - if err != nil { - return fmt.Errorf("failed to decode key: %w", err) - } - message = canonicaljson.CanonicalJSONAssumeValid(message) - if !ed25519.Verify(keyBytes, message, sigBytes) { - return ErrInvalidSignature - } - return nil -} diff --git a/filter.go b/filter.go index 54973dab..fd6de7a0 100644 --- a/filter.go +++ b/filter.go @@ -19,45 +19,43 @@ const ( // Filter is used by clients to specify how the server should filter responses to e.g. sync requests // Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering type Filter struct { - AccountData *FilterPart `json:"account_data,omitempty"` + AccountData FilterPart `json:"account_data,omitempty"` EventFields []string `json:"event_fields,omitempty"` EventFormat EventFormat `json:"event_format,omitempty"` - Presence *FilterPart `json:"presence,omitempty"` - Room *RoomFilter `json:"room,omitempty"` - - BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"` + Presence FilterPart `json:"presence,omitempty"` + Room RoomFilter `json:"room,omitempty"` } // RoomFilter is used to define filtering rules for room events type RoomFilter struct { - AccountData *FilterPart `json:"account_data,omitempty"` - Ephemeral *FilterPart `json:"ephemeral,omitempty"` + AccountData FilterPart `json:"account_data,omitempty"` + Ephemeral FilterPart `json:"ephemeral,omitempty"` IncludeLeave bool `json:"include_leave,omitempty"` NotRooms []id.RoomID `json:"not_rooms,omitempty"` Rooms []id.RoomID `json:"rooms,omitempty"` - State *FilterPart `json:"state,omitempty"` - Timeline *FilterPart `json:"timeline,omitempty"` + State FilterPart `json:"state,omitempty"` + Timeline FilterPart `json:"timeline,omitempty"` } // FilterPart is used to define filtering rules for specific categories of events type FilterPart struct { - NotRooms []id.RoomID `json:"not_rooms,omitempty"` - Rooms []id.RoomID `json:"rooms,omitempty"` - Limit int `json:"limit,omitempty"` - NotSenders []id.UserID `json:"not_senders,omitempty"` - NotTypes []event.Type `json:"not_types,omitempty"` - Senders []id.UserID `json:"senders,omitempty"` - Types []event.Type `json:"types,omitempty"` - ContainsURL *bool `json:"contains_url,omitempty"` - LazyLoadMembers bool `json:"lazy_load_members,omitempty"` - IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` - UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"` + NotRooms []id.RoomID `json:"not_rooms,omitempty"` + Rooms []id.RoomID `json:"rooms,omitempty"` + Limit int `json:"limit,omitempty"` + NotSenders []id.UserID `json:"not_senders,omitempty"` + NotTypes []event.Type `json:"not_types,omitempty"` + Senders []id.UserID `json:"senders,omitempty"` + Types []event.Type `json:"types,omitempty"` + ContainsURL *bool `json:"contains_url,omitempty"` + + LazyLoadMembers bool `json:"lazy_load_members,omitempty"` + IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` } // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { - return errors.New("bad event_format value") + return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") } return nil } @@ -69,7 +67,7 @@ func DefaultFilter() Filter { EventFields: nil, EventFormat: "client", Presence: DefaultFilterPart(), - Room: &RoomFilter{ + Room: RoomFilter{ AccountData: DefaultFilterPart(), Ephemeral: DefaultFilterPart(), IncludeLeave: false, @@ -82,8 +80,8 @@ func DefaultFilter() Filter { } // DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request -func DefaultFilterPart() *FilterPart { - return &FilterPart{ +func DefaultFilterPart() FilterPart { + return FilterPart{ NotRooms: nil, Rooms: nil, Limit: 20, diff --git a/format/htmlparser.go b/format/htmlparser.go index e0507d93..d099e8a7 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -13,7 +13,6 @@ import ( "strconv" "strings" - "go.mau.fi/util/exstrings" "golang.org/x/net/html" "maunium.net/go/mautrix/event" @@ -67,7 +66,6 @@ type LinkConverter func(text, href string, ctx Context) string type ColorConverter func(text, fg, bg string, ctx Context) string type CodeBlockConverter func(code, language string, ctx Context) string type PillConverter func(displayname, mxid, eventID string, ctx Context) string -type ImageConverter func(src, alt, title, width, height string, isEmoji bool) string const ContextKeyMentions = "_mentions" @@ -93,30 +91,6 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string } } -func onlyBacktickCount(line string) (count int) { - for i := 0; i < len(line); i++ { - if line[i] != '`' { - return -1 - } - count++ - } - return -} - -func DefaultMonospaceBlockConverter(code, language string, ctx Context) string { - if len(code) == 0 || code[len(code)-1] != '\n' { - code += "\n" - } - fence := "```" - for line := range strings.SplitSeq(code, "\n") { - count := onlyBacktickCount(strings.TrimSpace(line)) - if count >= len(fence) { - fence = strings.Repeat("`", count+1) - } - } - return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence) -} - // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { PillConverter PillConverter @@ -127,15 +101,12 @@ type HTMLParser struct { ItalicConverter TextConverter StrikethroughConverter TextConverter UnderlineConverter TextConverter - MathConverter TextConverter - MathBlockConverter TextConverter LinkConverter LinkConverter SpoilerConverter SpoilerConverter ColorConverter ColorConverter MonospaceBlockConverter CodeBlockConverter MonospaceConverter TextConverter TextConverter TextConverter - ImageConverter ImageConverter } // TaggedString is a string that also contains a HTML tag. @@ -212,6 +183,25 @@ func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string { return strings.Join(children, "\n") } +func LongestSequence(in string, of rune) int { + currentSeq := 0 + maxSeq := 0 + for _, chr := range in { + if chr == of { + currentSeq++ + } else { + if currentSeq > maxSeq { + maxSeq = currentSeq + } + currentSeq = 0 + } + } + if currentSeq > maxSeq { + maxSeq = currentSeq + } + return maxSeq +} + func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, ctx) switch node.Data { @@ -238,23 +228,14 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri if parser.MonospaceConverter != nil { return parser.MonospaceConverter(str, ctx) } - return SafeMarkdownCode(str) + surround := strings.Repeat("`", LongestSequence(str, '`')+1) + return fmt.Sprintf("%s%s%s", surround, str, surround) } return str } func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, ctx) - if node.Data == "span" || node.Data == "div" { - math, _ := parser.maybeGetAttribute(node, "data-mx-maths") - if math != "" && parser.MathConverter != nil { - if node.Data == "div" && parser.MathBlockConverter != nil { - str = parser.MathBlockConverter(math, ctx) - } else { - str = parser.MathConverter(math, ctx) - } - } - } if node.Data == "span" { reason, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler") if isSpoiler { @@ -311,28 +292,12 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string { } if parser.LinkConverter != nil { return parser.LinkConverter(str, href, ctx) - } else if str == href || - str == strings.TrimPrefix(href, "mailto:") || - str == strings.TrimPrefix(href, "http://") || - str == strings.TrimPrefix(href, "https://") { + } else if str == href { return str } return fmt.Sprintf("%s (%s)", str, href) } -func (parser *HTMLParser) imgToString(node *html.Node, ctx Context) string { - src := parser.getAttribute(node, "src") - alt := parser.getAttribute(node, "alt") - title := parser.getAttribute(node, "title") - width := parser.getAttribute(node, "width") - height := parser.getAttribute(node, "height") - _, isEmoji := parser.maybeGetAttribute(node, "data-mx-emoticon") - if parser.ImageConverter != nil { - return parser.ImageConverter(src, alt, title, width, height, isEmoji) - } - return alt -} - func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { ctx = ctx.WithTag(node.Data) switch node.Data { @@ -352,12 +317,8 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { return parser.linkToString(node, ctx) case "p": return parser.nodeToTagAwareString(node.FirstChild, ctx) - case "img": - return parser.imgToString(node, ctx) case "hr": return parser.HorizontalLine - case "input": - return parser.inputToString(node, ctx) case "pre": var preStr, language string if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" { @@ -372,28 +333,20 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { if parser.MonospaceBlockConverter != nil { return parser.MonospaceBlockConverter(preStr, language, ctx) } - return DefaultMonospaceBlockConverter(preStr, language, ctx) + if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' { + preStr += "\n" + } + return fmt.Sprintf("```%s\n%s```", language, preStr) default: return parser.nodeToTagAwareString(node.FirstChild, ctx) } } -func (parser *HTMLParser) inputToString(node *html.Node, ctx Context) string { - if len(ctx.TagStack) > 1 && ctx.TagStack[len(ctx.TagStack)-2] == "li" { - _, checked := parser.maybeGetAttribute(node, "checked") - if checked { - return "[x]" - } - return "[ ]" - } - return parser.nodeToTagAwareString(node.FirstChild, ctx) -} - func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString { switch node.Type { case html.TextNode: if !ctx.PreserveWhitespace { - node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", "")) + node.Data = strings.Replace(node.Data, "\n", "", -1) } if parser.TextConverter != nil { node.Data = parser.TextConverter(node.Data, ctx) @@ -459,35 +412,6 @@ func (parser *HTMLParser) Parse(htmlData string, ctx Context) string { return parser.nodeToTagAwareString(node, ctx) } -var TextHTMLParser = &HTMLParser{ - TabsToSpaces: 4, - Newline: "\n", - HorizontalLine: "\n---\n", - PillConverter: DefaultPillConverter, -} - -var MarkdownHTMLParser = &HTMLParser{ - TabsToSpaces: 4, - Newline: "\n", - HorizontalLine: "\n---\n", - PillConverter: DefaultPillConverter, - LinkConverter: func(text, href string, ctx Context) string { - if text == href { - return fmt.Sprintf("<%s>", href) - } - return fmt.Sprintf("[%s](%s)", text, href) - }, - MathConverter: func(s string, c Context) string { - return fmt.Sprintf("$%s$", s) - }, - MathBlockConverter: func(s string, c Context) string { - return fmt.Sprintf("$$\n%s\n$$", s) - }, - UnderlineConverter: func(s string, c Context) string { - return fmt.Sprintf("%s", s) - }, -} - // HTMLToText converts Matrix HTML into text with the default settings. func HTMLToText(html string) string { return (&HTMLParser{ @@ -498,12 +422,20 @@ func HTMLToText(html string) string { }).Parse(html, NewContext(context.TODO())) } -func HTMLToMarkdownFull(parser *HTMLParser, html string) (parsed string, mentions *event.Mentions) { - if parser == nil { - parser = MarkdownHTMLParser - } +func HTMLToMarkdownAndMentions(html string) (parsed string, mentions *event.Mentions) { ctx := NewContext(context.TODO()) - parsed = parser.Parse(html, ctx) + parsed = (&HTMLParser{ + TabsToSpaces: 4, + Newline: "\n", + HorizontalLine: "\n---\n", + PillConverter: DefaultPillConverter, + LinkConverter: func(text, href string, ctx Context) string { + if text == href { + return text + } + return fmt.Sprintf("[%s](%s)", text, href) + }, + }).Parse(html, ctx) mentionList, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID) mentions = &event.Mentions{ UserIDs: mentionList, @@ -515,6 +447,6 @@ func HTMLToMarkdownFull(parser *HTMLParser, html string) (parsed string, mention // // Currently, the only difference to HTMLToText is how links are formatted. func HTMLToMarkdown(html string) string { - parsed, _ := HTMLToMarkdownFull(nil, html) + parsed, _ := HTMLToMarkdownAndMentions(html) return parsed } diff --git a/format/markdown.go b/format/markdown.go index 77ced0dc..11f9f684 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -8,17 +8,14 @@ package format import ( "fmt" - "regexp" "strings" "github.com/yuin/goldmark" "github.com/yuin/goldmark/extension" "github.com/yuin/goldmark/renderer/html" - "go.mau.fi/util/exstrings" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format/mdext" - "maunium.net/go/mautrix/id" ) const paragraphStart = "

      " @@ -42,55 +39,6 @@ func UnwrapSingleParagraph(html string) string { return html } -var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]()])") - -func EscapeMarkdown(text string) string { - text = mdEscapeRegex.ReplaceAllString(text, "\\$1") - text = strings.ReplaceAll(text, ">", ">") - text = strings.ReplaceAll(text, "<", "<") - return text -} - -type uriAble interface { - String() string - URI() *id.MatrixURI -} - -func MarkdownMention(id uriAble) string { - return MarkdownMentionWithName(id.String(), id) -} - -func MarkdownMentionWithName(name string, id uriAble) string { - return MarkdownLink(name, id.URI().MatrixToURL()) -} - -func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string { - if name == "" { - name = id.String() - } - return MarkdownLink(name, id.URI(via...).MatrixToURL()) -} - -func MarkdownLink(name string, url string) string { - return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url)) -} - -func SafeMarkdownCode[T ~string](textInput T) string { - if textInput == "" { - return "` `" - } - text := strings.ReplaceAll(string(textInput), "\n", " ") - backtickCount := exstrings.LongestSequenceOf(text, '`') - if backtickCount == 0 { - return fmt.Sprintf("`%s`", text) - } - quotes := strings.Repeat("`", backtickCount+1) - if text[0] == '`' || text[len(text)-1] == '`' { - return fmt.Sprintf("%s %s %s", quotes, text, quotes) - } - return fmt.Sprintf("%s%s%s", quotes, text, quotes) -} - func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent { var buf strings.Builder err := renderer.Convert([]byte(text), &buf) @@ -101,16 +49,8 @@ func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.Message return HTMLToContent(htmlBody) } -func TextToContent(text string) event.MessageEventContent { - return event.MessageEventContent{ - MsgType: event.MsgText, - Body: text, - Mentions: &event.Mentions{}, - } -} - -func HTMLToContentFull(renderer *HTMLParser, html string) event.MessageEventContent { - text, mentions := HTMLToMarkdownFull(renderer, html) +func HTMLToContent(html string) event.MessageEventContent { + text, mentions := HTMLToMarkdownAndMentions(html) if html != text { return event.MessageEventContent{ FormattedBody: html, @@ -120,11 +60,11 @@ func HTMLToContentFull(renderer *HTMLParser, html string) event.MessageEventCont Mentions: mentions, } } - return TextToContent(text) -} - -func HTMLToContent(html string) event.MessageEventContent { - return HTMLToContentFull(nil, html) + return event.MessageEventContent{ + MsgType: event.MsgText, + Body: text, + Mentions: &event.Mentions{}, + } } func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEventContent { @@ -140,6 +80,10 @@ func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEve htmlBody = strings.Replace(text, "\n", "
      ", -1) return HTMLToContent(htmlBody) } else { - return TextToContent(text) + return event.MessageEventContent{ + MsgType: event.MsgText, + Body: text, + Mentions: &event.Mentions{}, + } } } diff --git a/format/markdown_test.go b/format/markdown_test.go index 46ea4886..10ae270c 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -158,56 +158,3 @@ func TestRenderMarkdown_DiscordUnderline(t *testing.T) { assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "")) } } - -var mathTests = map[string]string{ - "$foo$": `foo`, - "hello $foo$ world": `hello foo world`, - "$$\nfoo\nbar\n$$": `

      foo
      bar
      `, - "`$foo$`": `$foo$`, - "```\n$foo$\n```": `
      $foo$\n
      `, - "~~meow $foo$ asd~~": `meow foo asd`, - "$5 or $10": `$5 or $10`, - "5$ or 10$": `5$ or 10$`, - "$5 or 10$": `5 or 10`, - "$*500*$": `*500*`, - "$$\n*500*\n$$": `
      *500*
      `, - - // TODO: This doesn't work :( - // Maybe same reason as the spoiler wrapping not working? - //"~~$foo$~~": `foo`, -} - -func TestRenderMarkdown_Math(t *testing.T) { - renderer := goldmark.New(goldmark.WithExtensions(extension.Strikethrough, mdext.Math, mdext.EscapeHTML), format.HTMLOptions) - for markdown, html := range mathTests { - rendered := format.UnwrapSingleParagraph(render(renderer, markdown)) - assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n"), "with input %q", markdown) - } -} - -var customEmojiTests = map[string]string{ - `![:meow:](mxc://example.com/emoji.png "Emoji: meow")`: `:meow:`, -} - -func TestRenderMarkdown_CustomEmoji(t *testing.T) { - renderer := goldmark.New(goldmark.WithExtensions(mdext.CustomEmoji), format.HTMLOptions) - for markdown, html := range customEmojiTests { - rendered := format.UnwrapSingleParagraph(render(renderer, markdown)) - assert.Equal(t, html, rendered, "with input %q", markdown) - } -} - -var codeTests = map[string]string{ - "meow": "`meow`", - "me`ow": "``me`ow``", - "`me`ow": "`` `me`ow ``", - "me`ow`": "`` me`ow` ``", - "`meow`": "`` `meow` ``", - "`````````": "`````````` ````````` ``````````", -} - -func TestSafeMarkdownCode(t *testing.T) { - for input, expected := range codeTests { - assert.Equal(t, expected, format.SafeMarkdownCode(input), "with input %q", input) - } -} diff --git a/format/mdext/customemoji.go b/format/mdext/customemoji.go deleted file mode 100644 index 2884a5ea..00000000 --- a/format/mdext/customemoji.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mdext - -import ( - "bytes" - - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/ast" - "github.com/yuin/goldmark/renderer" - "github.com/yuin/goldmark/util" -) - -type extCustomEmoji struct{} -type customEmojiRenderer struct { - funcs functionCapturer -} - -// CustomEmoji is an extension that converts certain markdown images into Matrix custom emojis. -var CustomEmoji = &extCustomEmoji{} - -type functionCapturer struct { - renderImage renderer.NodeRendererFunc - renderText renderer.NodeRendererFunc - renderString renderer.NodeRendererFunc -} - -func (fc *functionCapturer) Register(kind ast.NodeKind, rendererFunc renderer.NodeRendererFunc) { - switch kind { - case ast.KindImage: - fc.renderImage = rendererFunc - case ast.KindText: - fc.renderText = rendererFunc - case ast.KindString: - fc.renderString = rendererFunc - } -} - -var ( - _ renderer.NodeRendererFuncRegisterer = (*functionCapturer)(nil) - _ renderer.Option = (*functionCapturer)(nil) -) - -func (fc *functionCapturer) SetConfig(cfg *renderer.Config) { - cfg.NodeRenderers[0].Value.(renderer.NodeRenderer).RegisterFuncs(fc) -} - -func (eeh *extCustomEmoji) Extend(m goldmark.Markdown) { - var fc functionCapturer - m.Renderer().AddOptions(&fc) - m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(&customEmojiRenderer{fc}, 0))) -} - -func (cer *customEmojiRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { - reg.Register(ast.KindImage, cer.renderImage) -} - -var emojiPrefix = []byte("Emoji: ") -var mxcPrefix = []byte("mxc://") - -func (cer *customEmojiRenderer) renderImage(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { - n, ok := node.(*ast.Image) - if ok && entering && bytes.HasPrefix(n.Title, emojiPrefix) && bytes.HasPrefix(n.Destination, mxcPrefix) { - n.Title = bytes.TrimPrefix(n.Title, emojiPrefix) - n.SetAttributeString("data-mx-emoticon", nil) - n.SetAttributeString("height", "32") - } - return cer.funcs.renderImage(w, source, node, entering) -} diff --git a/format/mdext/indentableparagraph.go b/format/mdext/indentableparagraph.go deleted file mode 100644 index a6ebd6c0..00000000 --- a/format/mdext/indentableparagraph.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mdext - -import ( - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/parser" - "github.com/yuin/goldmark/util" -) - -// indentableParagraphParser is the default paragraph parser with CanAcceptIndentedLine. -// Used when disabling CodeBlockParser (as disabling it without a replacement will make indented blocks disappear). -type indentableParagraphParser struct { - parser.BlockParser -} - -var defaultIndentableParagraphParser = &indentableParagraphParser{BlockParser: parser.NewParagraphParser()} - -func (b *indentableParagraphParser) CanAcceptIndentedLine() bool { - return true -} - -// FixIndentedParagraphs is a goldmark option which fixes indented paragraphs when disabling CodeBlockParser. -var FixIndentedParagraphs = goldmark.WithParserOptions(parser.WithBlockParsers(util.Prioritized(defaultIndentableParagraphParser, 500))) diff --git a/format/mdext/math.go b/format/mdext/math.go deleted file mode 100644 index e6a6ecc5..00000000 --- a/format/mdext/math.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mdext - -import ( - "bytes" - "fmt" - stdhtml "html" - "regexp" - "strings" - "unicode" - - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/ast" - "github.com/yuin/goldmark/parser" - "github.com/yuin/goldmark/renderer" - "github.com/yuin/goldmark/renderer/html" - "github.com/yuin/goldmark/text" - "github.com/yuin/goldmark/util" -) - -var astKindMath = ast.NewNodeKind("Math") - -type astMath struct { - ast.BaseInline - value []byte -} - -func (n *astMath) Dump(source []byte, level int) { - ast.DumpHelper(n, source, level, nil, nil) -} - -func (n *astMath) Kind() ast.NodeKind { - return astKindMath -} - -type astMathBlock struct { - ast.BaseBlock -} - -func (n *astMathBlock) Dump(source []byte, level int) { - ast.DumpHelper(n, source, level, nil, nil) -} - -func (n *astMathBlock) Kind() ast.NodeKind { - return astKindMath -} - -type inlineMathParser struct{} - -var defaultInlineMathParser = &inlineMathParser{} - -func NewInlineMathParser() parser.InlineParser { - return defaultInlineMathParser -} - -const mathDelimiter = '$' - -func (s *inlineMathParser) Trigger() []byte { - return []byte{mathDelimiter} -} - -// This ignores lines where there's no space after the closing $ to avoid false positives -var latexInlineRegexp = regexp.MustCompile(`^(\$[^$]*\$)(?:$|\s)`) - -func (s *inlineMathParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { - before := block.PrecendingCharacter() - // Ignore lines where the opening $ comes after a letter or number to avoid false positives - if unicode.IsLetter(before) || unicode.IsNumber(before) { - return nil - } - line, segment := block.PeekLine() - idx := latexInlineRegexp.FindSubmatchIndex(line) - if idx == nil { - return nil - } - block.Advance(idx[3]) - return &astMath{ - value: block.Value(text.NewSegment(segment.Start+1, segment.Start+idx[3]-1)), - } -} - -func (s *inlineMathParser) CloseBlock(parent ast.Node, pc parser.Context) { - // nothing to do -} - -type blockMathParser struct{} - -var defaultBlockMathParser = &blockMathParser{} - -func NewBlockMathParser() parser.BlockParser { - return defaultBlockMathParser -} - -var mathBlockInfoKey = parser.NewContextKey() - -type mathBlockData struct { - indent int - length int - node ast.Node -} - -func (b *blockMathParser) Trigger() []byte { - return []byte{'$'} -} - -func (b *blockMathParser) Open(parent ast.Node, reader text.Reader, pc parser.Context) (ast.Node, parser.State) { - line, _ := reader.PeekLine() - pos := pc.BlockOffset() - if pos < 0 || (line[pos] != mathDelimiter) { - return nil, parser.NoChildren - } - findent := pos - i := pos - for ; i < len(line) && line[i] == mathDelimiter; i++ { - } - oFenceLength := i - pos - if oFenceLength < 2 { - return nil, parser.NoChildren - } - if i < len(line)-1 { - rest := line[i:] - left := util.TrimLeftSpaceLength(rest) - right := util.TrimRightSpaceLength(rest) - if left < len(rest)-right { - value := rest[left : len(rest)-right] - if bytes.IndexByte(value, mathDelimiter) > -1 { - return nil, parser.NoChildren - } - } - } - node := &astMathBlock{} - pc.Set(mathBlockInfoKey, &mathBlockData{findent, oFenceLength, node}) - return node, parser.NoChildren - -} - -func (b *blockMathParser) Continue(node ast.Node, reader text.Reader, pc parser.Context) parser.State { - line, segment := reader.PeekLine() - fdata := pc.Get(mathBlockInfoKey).(*mathBlockData) - - w, pos := util.IndentWidth(line, reader.LineOffset()) - if w < 4 { - i := pos - for ; i < len(line) && line[i] == mathDelimiter; i++ { - } - length := i - pos - if length >= fdata.length && util.IsBlank(line[i:]) { - newline := 1 - if line[len(line)-1] != '\n' { - newline = 0 - } - reader.Advance(segment.Stop - segment.Start - newline + segment.Padding) - return parser.Close - } - } - pos, padding := util.IndentPositionPadding(line, reader.LineOffset(), segment.Padding, fdata.indent) - if pos < 0 { - pos = util.FirstNonSpacePosition(line) - if pos < 0 { - pos = 0 - } - padding = 0 - } - seg := text.NewSegmentPadding(segment.Start+pos, segment.Stop, padding) - seg.ForceNewline = true // EOF as newline - node.Lines().Append(seg) - reader.AdvanceAndSetPadding(segment.Stop-segment.Start-pos-1, padding) - return parser.Continue | parser.NoChildren -} - -func (b *blockMathParser) Close(node ast.Node, reader text.Reader, pc parser.Context) { - fdata := pc.Get(mathBlockInfoKey).(*mathBlockData) - if fdata.node == node { - pc.Set(mathBlockInfoKey, nil) - } -} - -func (b *blockMathParser) CanInterruptParagraph() bool { - return true -} - -func (b *blockMathParser) CanAcceptIndentedLine() bool { - return false -} - -type mathHTMLRenderer struct { - html.Config -} - -func NewMathHTMLRenderer(opts ...html.Option) renderer.NodeRenderer { - r := &mathHTMLRenderer{ - Config: html.NewConfig(), - } - for _, opt := range opts { - opt.SetHTMLOption(&r.Config) - } - return r -} - -func (r *mathHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { - reg.Register(astKindMath, r.renderMath) -} - -func (r *mathHTMLRenderer) renderMath(w util.BufWriter, source []byte, n ast.Node, entering bool) (ast.WalkStatus, error) { - if entering { - tag := "span" - var tex string - switch typed := n.(type) { - case *astMathBlock: - tag = "div" - tex = string(n.Lines().Value(source)) - case *astMath: - tex = string(typed.value) - } - tex = stdhtml.EscapeString(strings.TrimSpace(tex)) - _, _ = fmt.Fprintf(w, `<%s data-mx-maths="%s">%s`, tag, tex, strings.ReplaceAll(tex, "\n", "
      "), tag) - } - return ast.WalkSkipChildren, nil -} - -type math struct{} - -// Math is an extension that allow you to use math like '$$text$$'. -var Math = &math{} - -func (e *math) Extend(m goldmark.Markdown) { - m.Parser().AddOptions(parser.WithInlineParsers( - util.Prioritized(NewInlineMathParser(), 500), - ), parser.WithBlockParsers( - util.Prioritized(NewBlockMathParser(), 850), - )) - m.Renderer().AddOptions(renderer.WithNodeRenderers( - util.Prioritized(NewMathHTMLRenderer(), 500), - )) -} diff --git a/go.mod b/go.mod index 49a1d4e4..4ad4143c 100644 --- a/go.mod +++ b/go.mod @@ -1,42 +1,39 @@ module maunium.net/go/mautrix -go 1.25.0 - -toolchain go1.26.0 +go 1.22 require ( - filippo.io/edwards25519 v1.2.0 github.com/chzyer/readline v1.5.1 - github.com/coder/websocket v1.8.14 - github.com/lib/pq v1.11.2 - github.com/mattn/go-sqlite3 v1.14.34 - github.com/rs/xid v1.6.0 - github.com/rs/zerolog v1.34.0 + github.com/gorilla/mux v1.8.0 + github.com/gorilla/websocket v1.5.0 + github.com/lib/pq v1.10.9 + github.com/mattn/go-sqlite3 v1.14.22 + github.com/rs/xid v1.5.0 + github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/stretchr/testify v1.11.1 - github.com/tidwall/gjson v1.18.0 + github.com/stretchr/testify v1.9.0 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.16 - go.mau.fi/util v0.9.6 - go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.48.0 - golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa - golang.org/x/net v0.50.0 - golang.org/x/sync v0.19.0 + github.com/yuin/goldmark v1.7.4 + go.mau.fi/util v0.7.0 + go.mau.fi/zeroconfig v0.1.3 + golang.org/x/crypto v0.26.0 + golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa + golang.org/x/net v0.28.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) require ( - github.com/coreos/go-systemd/v22 v22.6.0 // indirect + github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.14 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/text v0.17.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 871a5156..0adc7117 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= -filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= @@ -8,70 +6,66 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= -github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= -github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= -github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= -github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= -github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= -github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw= +github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= -github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= -github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= -github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= -github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= -github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= -go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= -go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= -go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= -golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= +github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +go.mau.fi/util v0.7.0 h1:l31z+ivrSQw+cv/9eFebEqtQW2zhxivGypn+JT0h/ws= +go.mau.fi/util v0.7.0/go.mod h1:bWYreIoTULL/UiRbZdfddPh7uWDFW5yX4YCv5FB0eE0= +go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= +go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/hicli/cryptohelper.go b/hicli/cryptohelper.go new file mode 100644 index 00000000..2a2e9626 --- /dev/null +++ b/hicli/cryptohelper.go @@ -0,0 +1,65 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type hiCryptoHelper HiClient + +var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil) + +func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*event.EncryptedEventContent, error) { + roomMeta, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return nil, fmt.Errorf("unknown room") + } + return (*HiClient)(h).Encrypt(ctx, roomMeta, evtType, content) +} + +func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return h.Crypto.DecryptMegolmEvent(ctx, evt) +} + +func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) +} + +func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { + err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ + userID: {deviceID}, + h.Account.UserID: {"*"}, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Failed to send room key request") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Sent room key request") + } +} + +func (h *hiCryptoHelper) Init(ctx context.Context) error { + return nil +} diff --git a/hicli/database/account.go b/hicli/database/account.go new file mode 100644 index 00000000..49b50771 --- /dev/null +++ b/hicli/database/account.go @@ -0,0 +1,65 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1` + putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2` + upsertAccountQuery = ` + INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id) + DO UPDATE SET device_id = excluded.device_id, + access_token = excluded.access_token, + homeserver_url = excluded.homeserver_url, + next_batch = excluded.next_batch + ` +) + +type AccountQuery struct { + *dbutil.QueryHelper[*Account] +} + +func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) { + err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID) + return +} + +func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) { + return aq.QueryOne(ctx, getAccountQuery, userID) +} + +func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error { + return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID) +} + +func (aq *AccountQuery) Put(ctx context.Context, account *Account) error { + return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...) +} + +type Account struct { + UserID id.UserID + DeviceID id.DeviceID + AccessToken string + HomeserverURL string + NextBatch string +} + +func (a *Account) Scan(row dbutil.Scannable) (*Account, error) { + return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch)) +} + +func (a *Account) sqlVariables() []any { + return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch} +} diff --git a/hicli/database/accountdata.go b/hicli/database/accountdata.go new file mode 100644 index 00000000..963886c3 --- /dev/null +++ b/hicli/database/accountdata.go @@ -0,0 +1,71 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "encoding/json" + "unsafe" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + upsertAccountDataQuery = ` + INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3) + ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content + ` + upsertRoomAccountDataQuery = ` + INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content + ` +) + +type AccountDataQuery struct { + *dbutil.QueryHelper[*AccountData] +} + +func unsafeJSONString(content json.RawMessage) *string { + if content == nil { + return nil + } + str := unsafe.String(unsafe.SliceData(content), len(content)) + return &str +} + +func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content)) +} + +func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content)) +} + +type AccountData struct { + UserID id.UserID + RoomID id.RoomID + Type string + Content json.RawMessage +} + +func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) { + var roomID sql.NullString + err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content)) + if err != nil { + return nil, err + } + a.RoomID = id.RoomID(roomID.String) + return a, nil +} + +func (a *AccountData) sqlVariables() []any { + return []any{a.UserID, dbutil.StrPtr(a.RoomID), a.Type, unsafeJSONString(a.Content)} +} diff --git a/hicli/database/database.go b/hicli/database/database.go new file mode 100644 index 00000000..601ca64b --- /dev/null +++ b/hicli/database/database.go @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/hicli/database/upgrades" +) + +type Database struct { + *dbutil.Database + + Account AccountQuery + AccountData AccountDataQuery + Room RoomQuery + Event EventQuery + CurrentState CurrentStateQuery + Timeline TimelineQuery + SessionRequest SessionRequestQuery + Receipt ReceiptQuery +} + +func New(rawDB *dbutil.Database) *Database { + rawDB.UpgradeTable = upgrades.Table + eventQH := dbutil.MakeQueryHelper(rawDB, newEvent) + return &Database{ + Database: rawDB, + + Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)}, + AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)}, + Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)}, + Event: EventQuery{QueryHelper: eventQH}, + CurrentState: CurrentStateQuery{QueryHelper: eventQH}, + Timeline: TimelineQuery{QueryHelper: eventQH}, + SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)}, + Receipt: ReceiptQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newReceipt)}, + } +} + +func newSessionRequest(_ *dbutil.QueryHelper[*SessionRequest]) *SessionRequest { + return &SessionRequest{} +} + +func newEvent(_ *dbutil.QueryHelper[*Event]) *Event { + return &Event{} +} + +func newRoom(_ *dbutil.QueryHelper[*Room]) *Room { + return &Room{} +} + +func newReceipt(_ *dbutil.QueryHelper[*Receipt]) *Receipt { + return &Receipt{} +} + +func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData { + return &AccountData{} +} + +func newAccount(_ *dbutil.QueryHelper[*Account]) *Account { + return &Account{} +} diff --git a/hicli/database/event.go b/hicli/database/event.go new file mode 100644 index 00000000..de21e317 --- /dev/null +++ b/hicli/database/event.go @@ -0,0 +1,438 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getEventBaseQuery = ` + SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + FROM event + ` + getEventByRowID = getEventBaseQuery + `WHERE rowid = $1` + getManyEventsByRowID = getEventBaseQuery + `WHERE rowid IN (%s)` + getEventByID = getEventBaseQuery + `WHERE event_id = $1` + getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` + insertEventBaseQuery = ` + INSERT INTO event ( + room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + ` + insertEventQuery = insertEventBaseQuery + `RETURNING rowid` + upsertEventQuery = insertEventBaseQuery + ` + ON CONFLICT (event_id) DO UPDATE + SET decrypted=COALESCE(event.decrypted, excluded.decrypted), + decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type), + redacted_by=COALESCE(event.redacted_by, excluded.redacted_by), + decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END, + timestamp=excluded.timestamp, + unsigned=COALESCE(excluded.unsigned, event.unsigned) + ON CONFLICT (transaction_id) DO UPDATE + SET event_id=excluded.event_id, + timestamp=excluded.timestamp, + unsigned=excluded.unsigned + RETURNING rowid + ` + updateEventIDQuery = `UPDATE event SET event_id=$2 WHERE rowid=$1` + updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` + getEventReactionsQuery = getEventBaseQuery + ` + WHERE room_id = ? + AND type = 'm.reaction' + AND relation_type = 'm.annotation' + AND redacted_by IS NULL + AND relates_to IN (%s) + ` + getEventEditRowIDsQuery = ` + SELECT main.event_id, edit.rowid + FROM event main + JOIN event edit ON + edit.room_id = main.room_id + AND edit.relates_to = main.event_id + AND edit.relation_type = 'm.replace' + AND edit.type = main.type + AND edit.sender = main.sender + AND edit.redacted_by IS NULL + WHERE main.event_id IN (%s) + ORDER BY main.event_id, edit.timestamp + ` + setLastEditRowIDQuery = ` + UPDATE event SET last_edit_rowid = $2 WHERE event_id = $1 + ` + updateReactionCountsQuery = `UPDATE event SET reactions = $2 WHERE event_id = $1` +) + +type EventQuery struct { + *dbutil.QueryHelper[*Event] +} + +func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) { + return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID) +} + +func (eq *EventQuery) GetByID(ctx context.Context, eventID id.EventID) (*Event, error) { + return eq.QueryOne(ctx, getEventByID, eventID) +} + +func (eq *EventQuery) GetByRowID(ctx context.Context, rowID EventRowID) (*Event, error) { + return eq.QueryOne(ctx, getEventByRowID, rowID) +} + +func (eq *EventQuery) GetByRowIDs(ctx context.Context, rowIDs ...EventRowID) ([]*Event, error) { + query, params := buildMultiEventGetFunction(nil, rowIDs, getManyEventsByRowID) + return eq.QueryMany(ctx, query, params...) +} + +func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID EventRowID, err error) { + err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID) + if err == nil { + evt.RowID = rowID + } + return +} + +func (eq *EventQuery) Insert(ctx context.Context, evt *Event) (rowID EventRowID, err error) { + err = eq.GetDB().QueryRow(ctx, insertEventQuery, evt.sqlVariables()...).Scan(&rowID) + if err == nil { + evt.RowID = rowID + } + return +} + +func (eq *EventQuery) UpdateID(ctx context.Context, rowID EventRowID, newID id.EventID) error { + return eq.Exec(ctx, updateEventIDQuery, rowID, newID) +} + +func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error { + return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) +} + +func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error { + eventIDs := make([]id.EventID, 0) + eventMap := make(map[id.EventID]*Event) + for i, evt := range events { + if evt.Reactions == nil { + eventIDs[i] = evt.ID + eventMap[evt.ID] = evt + } + } + result, err := eq.GetReactions(ctx, roomID, eventIDs...) + if err != nil { + return err + } + for evtID, res := range result { + eventMap[evtID].Reactions = res.Counts + } + return nil +} + +func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, events []*Event) error { + eventIDs := make([]id.EventID, 0) + eventMap := make(map[id.EventID]*Event) + for i, evt := range events { + if evt.LastEditRowID == nil { + eventIDs[i] = evt.ID + eventMap[evt.ID] = evt + } + } + return eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + result, err := eq.GetEditRowIDs(ctx, roomID, eventIDs...) + if err != nil { + return err + } + for evtID, res := range result { + lastEditRowID := res[len(res)-1] + eventMap[evtID].LastEditRowID = &lastEditRowID + delete(eventMap, evtID) + err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, lastEditRowID) + if err != nil { + return err + } + } + var zero EventRowID + for evtID, evt := range eventMap { + evt.LastEditRowID = &zero + err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, zero) + if err != nil { + return err + } + } + return nil + }) +} + +var reactionKeyPath = exgjson.Path("m.relates_to", "key") + +type GetReactionsResult struct { + Events []*Event + Counts map[string]int +} + +func buildMultiEventGetFunction[T any](preParams []any, eventIDs []T, query string) (string, []any) { + params := make([]any, len(preParams)+len(eventIDs)) + copy(params, preParams) + for i, evtID := range eventIDs { + params[i+len(preParams)] = evtID + } + placeholders := strings.Repeat("?,", len(eventIDs)) + placeholders = placeholders[:len(placeholders)-1] + return fmt.Sprintf(query, placeholders), params +} + +type editRowIDTuple struct { + eventID id.EventID + editRowID EventRowID +} + +func (eq *EventQuery) GetEditRowIDs(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID][]EventRowID, error) { + query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventEditRowIDsQuery) + rows, err := eq.GetDB().Query(ctx, query, params...) + output := make(map[id.EventID][]EventRowID) + return output, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (tuple editRowIDTuple, err error) { + err = row.Scan(&tuple.eventID, &tuple.editRowID) + return + }, err).Iter(func(tuple editRowIDTuple) (bool, error) { + output[tuple.eventID] = append(output[tuple.eventID], tuple.editRowID) + return true, nil + }) +} + +func (eq *EventQuery) GetReactions(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID]*GetReactionsResult, error) { + result := make(map[id.EventID]*GetReactionsResult, len(eventIDs)) + for _, evtID := range eventIDs { + result[evtID] = &GetReactionsResult{Counts: make(map[string]int)} + } + return result, eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventReactionsQuery) + events, err := eq.QueryMany(ctx, query, params...) + if err != nil { + return err + } else if len(events) == 0 { + return nil + } + for _, evt := range events { + dest := result[evt.RelatesTo] + dest.Events = append(dest.Events, evt) + keyRes := gjson.GetBytes(evt.Content, reactionKeyPath) + if keyRes.Type == gjson.String { + dest.Counts[keyRes.Str]++ + } + } + for evtID, res := range result { + if len(res.Counts) > 0 { + err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts}) + if err != nil { + return err + } + } + } + return nil + }) +} + +type EventRowID int64 + +func (m EventRowID) GetMassInsertValues() [1]any { + return [1]any{m} +} + +type Event struct { + RowID EventRowID `json:"rowid"` + TimelineRowID TimelineRowID `json:"timeline_rowid"` + + RoomID id.RoomID `json:"room_id"` + ID id.EventID `json:"event_id"` + Sender id.UserID `json:"sender"` + Type string `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Timestamp time.Time `json:"timestamp"` + + Content json.RawMessage `json:"content"` + Decrypted json.RawMessage `json:"decrypted,omitempty"` + DecryptedType string `json:"decrypted_type,omitempty"` + Unsigned json.RawMessage `json:"unsigned,omitempty"` + + TransactionID string `json:"transaction_id,omitempty"` + + RedactedBy id.EventID `json:"redacted_by,omitempty"` + RelatesTo id.EventID `json:"relates_to,omitempty"` + RelationType event.RelationType `json:"relation_type,omitempty"` + + MegolmSessionID id.SessionID `json:"-,omitempty"` + DecryptionError string `json:"decryption_error,omitempty"` + + Reactions map[string]int `json:"reactions,omitempty"` + LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"` +} + +func MautrixToEvent(evt *event.Event) *Event { + dbEvt := &Event{ + RoomID: evt.RoomID, + ID: evt.ID, + Sender: evt.Sender, + Type: evt.Type.Type, + StateKey: evt.StateKey, + Timestamp: time.UnixMilli(evt.Timestamp), + Content: evt.Content.VeryRaw, + MegolmSessionID: getMegolmSessionID(evt), + TransactionID: evt.Unsigned.TransactionID, + } + if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") { + dbEvt.TransactionID = "" + } + dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt) + dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned) + if evt.Unsigned.RedactedBecause != nil { + dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID + } + return dbEvt +} + +func (e *Event) AsRawMautrix() *event.Event { + evt := &event.Event{ + RoomID: e.RoomID, + ID: e.ID, + Sender: e.Sender, + Type: event.Type{Type: e.Type, Class: event.MessageEventType}, + StateKey: e.StateKey, + Timestamp: e.Timestamp.UnixMilli(), + Content: event.Content{VeryRaw: e.Content}, + } + if e.Decrypted != nil { + evt.Content.VeryRaw = e.Decrypted + evt.Type.Type = e.DecryptedType + evt.Mautrix.WasEncrypted = true + } + if e.StateKey != nil { + evt.Type.Class = event.StateEventType + } + _ = json.Unmarshal(e.Unsigned, &evt.Unsigned) + return evt +} + +func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { + var timestamp int64 + var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, decryptedType sql.NullString + err := row.Scan( + &e.RowID, + &e.TimelineRowID, + &e.RoomID, + &e.ID, + &e.Sender, + &e.Type, + &e.StateKey, + ×tamp, + (*[]byte)(&e.Content), + (*[]byte)(&e.Decrypted), + &decryptedType, + (*[]byte)(&e.Unsigned), + &transactionID, + &redactedBy, + &relatesTo, + &relationType, + &megolmSessionID, + &decryptionError, + dbutil.JSON{Data: &e.Reactions}, + &e.LastEditRowID, + ) + if err != nil { + return nil, err + } + e.Timestamp = time.UnixMilli(timestamp) + e.TransactionID = transactionID.String + e.RedactedBy = id.EventID(redactedBy.String) + e.RelatesTo = id.EventID(relatesTo.String) + e.RelationType = event.RelationType(relatesTo.String) + e.MegolmSessionID = id.SessionID(megolmSessionID.String) + e.DecryptedType = decryptedType.String + e.DecryptionError = decryptionError.String + return e, nil +} + +var relatesToPath = exgjson.Path("m.relates_to", "event_id") +var relationTypePath = exgjson.Path("m.relates_to", "rel_type") + +func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) { + if evt.StateKey != nil { + return "", "" + } + return GetRelatesToFromBytes(evt.Content.VeryRaw) +} + +func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) { + results := gjson.GetManyBytes(content, relatesToPath, relationTypePath) + if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String { + return id.EventID(results[0].Str), event.RelationType(results[1].Str) + } + return "", "" +} + +func getMegolmSessionID(evt *event.Event) id.SessionID { + if evt.Type != event.EventEncrypted { + return "" + } + res := gjson.GetBytes(evt.Content.VeryRaw, "session_id") + if res.Exists() && res.Type == gjson.String { + return id.SessionID(res.Str) + } + return "" +} + +func (e *Event) sqlVariables() []any { + var reactions any + if e.Reactions != nil { + reactions = e.Reactions + } + return []any{ + e.RoomID, + e.ID, + e.Sender, + e.Type, + e.StateKey, + e.Timestamp.UnixMilli(), + unsafeJSONString(e.Content), + unsafeJSONString(e.Decrypted), + dbutil.StrPtr(e.DecryptedType), + unsafeJSONString(e.Unsigned), + dbutil.StrPtr(e.TransactionID), + dbutil.StrPtr(e.RedactedBy), + dbutil.StrPtr(e.RelatesTo), + dbutil.StrPtr(e.RelationType), + dbutil.StrPtr(e.MegolmSessionID), + dbutil.StrPtr(e.DecryptionError), + dbutil.JSON{Data: reactions}, + e.LastEditRowID, + } +} + +func (e *Event) CanUseForPreview() bool { + return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || + (e.Type == event.EventEncrypted.Type && + (e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) && + e.RelationType != event.RelReplace +} + +func (e *Event) BumpsSortingTimestamp() bool { + return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) && + e.RelationType != event.RelReplace +} diff --git a/hicli/database/receipt.go b/hicli/database/receipt.go new file mode 100644 index 00000000..a3370fba --- /dev/null +++ b/hicli/database/receipt.go @@ -0,0 +1,81 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "time" + + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exslices" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + upsertReceiptQuery = ` + INSERT INTO receipt (room_id, user_id, receipt_type, thread_id, event_id, timestamp) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id, receipt_type, thread_id) DO UPDATE + SET event_id = excluded.event_id, + timestamp = excluded.timestamp + ` +) + +var receiptMassInserter = dbutil.NewMassInsertBuilder[*Receipt, [1]any](upsertReceiptQuery, "($1, $%d, $%d, $%d, $%d, $%d)") + +type ReceiptQuery struct { + *dbutil.QueryHelper[*Receipt] +} + +func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error { + return rq.Exec(ctx, upsertReceiptQuery, receipt.sqlVariables()...) +} + +func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error { + if len(receipts) > 1000 { + return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + for _, receiptChunk := range exslices.Chunk(receipts, 200) { + err := rq.PutMany(ctx, roomID, receiptChunk...) + if err != nil { + return err + } + } + return nil + }) + } + query, params := receiptMassInserter.Build([1]any{roomID}, receipts) + return rq.Exec(ctx, query, params...) +} + +type Receipt struct { + RoomID id.RoomID + UserID id.UserID + ReceiptType event.ReceiptType + ThreadID event.ThreadID + EventID id.EventID + Timestamp time.Time +} + +func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) { + var ts int64 + err := row.Scan(&r.RoomID, &r.UserID, &r.ReceiptType, &r.ThreadID, &r.EventID, &ts) + if err != nil { + return nil, err + } + r.Timestamp = time.UnixMilli(ts) + return r, nil +} + +func (r *Receipt) sqlVariables() []any { + return []any{r.RoomID, r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} +} + +func (r *Receipt) GetMassInsertValues() [5]any { + return [5]any{r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} +} diff --git a/hicli/database/room.go b/hicli/database/room.go new file mode 100644 index 00000000..92adc279 --- /dev/null +++ b/hicli/database/room.go @@ -0,0 +1,221 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getRoomBaseQuery = ` + SELECT room_id, creation_content, name, name_quality, avatar, topic, canonical_alias, + lazy_load_summary, encryption_event, has_member_list, + preview_event_rowid, sorting_timestamp, prev_batch + FROM room + ` + getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1` + ensureRoomExistsQuery = ` + INSERT INTO room (room_id) VALUES ($1) + ON CONFLICT (room_id) DO NOTHING + ` + upsertRoomFromSyncQuery = ` + UPDATE room + SET creation_content = COALESCE(room.creation_content, $2), + name = COALESCE($3, room.name), + name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END, + avatar = COALESCE($5, room.avatar), + topic = COALESCE($6, room.topic), + canonical_alias = COALESCE($7, room.canonical_alias), + lazy_load_summary = COALESCE($8, room.lazy_load_summary), + encryption_event = COALESCE($9, room.encryption_event), + has_member_list = room.has_member_list OR $10, + preview_event_rowid = COALESCE($11, room.preview_event_rowid), + sorting_timestamp = COALESCE($12, room.sorting_timestamp), + prev_batch = COALESCE($13, room.prev_batch) + WHERE room_id = $1 + ` + setRoomPrevBatchQuery = ` + UPDATE room SET prev_batch = $2 WHERE room_id = $1 + ` + updateRoomPreviewIfLaterOnTimelineQuery = ` + UPDATE room + SET preview_event_rowid = $2 + WHERE room_id = $1 + AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1) + > COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0) + RETURNING preview_event_rowid + ` +) + +type RoomQuery struct { + *dbutil.QueryHelper[*Room] +} + +func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) { + return rq.QueryOne(ctx, getRoomByIDQuery, roomID) +} + +func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error { + return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...) +} + +func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error { + return rq.Exec(ctx, ensureRoomExistsQuery, roomID) +} + +func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error { + return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch) +} + +func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) { + var newPreviewRowID EventRowID + err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } else if err == nil { + previewChanged = newPreviewRowID == rowID + } + return +} + +type NameQuality int + +const ( + NameQualityNil NameQuality = iota + NameQualityParticipants + NameQualityCanonicalAlias + NameQualityExplicit +) + +type Room struct { + ID id.RoomID `json:"room_id"` + CreationContent *event.CreateEventContent `json:"creation_content,omitempty"` + + Name *string `json:"name,omitempty"` + NameQuality NameQuality `json:"name_quality"` + Avatar *id.ContentURI `json:"avatar,omitempty"` + Topic *string `json:"topic,omitempty"` + CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"` + + LazyLoadSummary *mautrix.LazyLoadSummary `json:"lazy_load_summary,omitempty"` + + EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` + HasMemberList bool `json:"has_member_list"` + + PreviewEventRowID EventRowID `json:"preview_event_rowid"` + SortingTimestamp time.Time `json:"sorting_timestamp"` + + PrevBatch string `json:"prev_batch"` +} + +func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { + if r.Name != nil && r.NameQuality >= other.NameQuality { + other.Name = r.Name + other.NameQuality = r.NameQuality + hasChanges = true + } + if r.Avatar != nil { + other.Avatar = r.Avatar + hasChanges = true + } + if r.Topic != nil { + other.Topic = r.Topic + hasChanges = true + } + if r.CanonicalAlias != nil { + other.CanonicalAlias = r.CanonicalAlias + hasChanges = true + } + if r.LazyLoadSummary != nil { + other.LazyLoadSummary = r.LazyLoadSummary + hasChanges = true + } + if r.EncryptionEvent != nil && other.EncryptionEvent == nil { + other.EncryptionEvent = r.EncryptionEvent + hasChanges = true + } + other.HasMemberList = other.HasMemberList || r.HasMemberList + if r.PreviewEventRowID > other.PreviewEventRowID { + other.PreviewEventRowID = r.PreviewEventRowID + hasChanges = true + } + if r.SortingTimestamp.After(other.SortingTimestamp) { + other.SortingTimestamp = r.SortingTimestamp + hasChanges = true + } + if r.PrevBatch != "" && other.PrevBatch == "" { + other.PrevBatch = r.PrevBatch + hasChanges = true + } + return +} + +func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { + var prevBatch sql.NullString + var previewEventRowID, sortingTimestamp sql.NullInt64 + err := row.Scan( + &r.ID, + dbutil.JSON{Data: &r.CreationContent}, + &r.Name, + &r.NameQuality, + &r.Avatar, + &r.Topic, + &r.CanonicalAlias, + dbutil.JSON{Data: &r.LazyLoadSummary}, + dbutil.JSON{Data: &r.EncryptionEvent}, + &r.HasMemberList, + &previewEventRowID, + &sortingTimestamp, + &prevBatch, + ) + if err != nil { + return nil, err + } + r.PrevBatch = prevBatch.String + r.PreviewEventRowID = EventRowID(previewEventRowID.Int64) + r.SortingTimestamp = time.UnixMilli(sortingTimestamp.Int64) + return r, nil +} + +func (r *Room) sqlVariables() []any { + return []any{ + r.ID, + dbutil.JSONPtr(r.CreationContent), + r.Name, + r.NameQuality, + r.Avatar, + r.Topic, + r.CanonicalAlias, + dbutil.JSONPtr(r.LazyLoadSummary), + dbutil.JSONPtr(r.EncryptionEvent), + r.HasMemberList, + dbutil.NumPtr(r.PreviewEventRowID), + dbutil.UnixMilliPtr(r.SortingTimestamp), + dbutil.StrPtr(r.PrevBatch), + } +} + +func (r *Room) BumpSortingTimestamp(evt *Event) bool { + if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp) { + return false + } + r.SortingTimestamp = evt.Timestamp + now := time.Now() + if r.SortingTimestamp.After(now) { + r.SortingTimestamp = now + } + return true +} diff --git a/hicli/database/sessionrequest.go b/hicli/database/sessionrequest.go new file mode 100644 index 00000000..6690c13f --- /dev/null +++ b/hicli/database/sessionrequest.go @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + putSessionRequestQueueEntry = ` + INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (session_id) DO UPDATE + SET min_index = MIN(excluded.min_index, session_request.min_index), + backup_checked = excluded.backup_checked OR session_request.backup_checked, + request_sent = excluded.request_sent OR session_request.request_sent + ` + removeSessionRequestQuery = ` + DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2 + ` + getNextSessionsToRequestQuery = ` + SELECT room_id, session_id, sender, min_index, backup_checked, request_sent + FROM session_request + WHERE request_sent = false OR backup_checked = false + ORDER BY backup_checked, rowid + LIMIT $1 + ` +) + +type SessionRequestQuery struct { + *dbutil.QueryHelper[*SessionRequest] +} + +func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) { + return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count) +} + +func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error { + return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex) +} + +func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error { + return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...) +} + +type SessionRequest struct { + RoomID id.RoomID + SessionID id.SessionID + Sender id.UserID + MinIndex uint32 + BackupChecked bool + RequestSent bool +} + +func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) { + return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent)) +} + +func (s *SessionRequest) sqlVariables() []any { + return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent} +} diff --git a/hicli/database/state.go b/hicli/database/state.go new file mode 100644 index 00000000..845de6ed --- /dev/null +++ b/hicli/database/state.go @@ -0,0 +1,47 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + setCurrentStateQuery = ` + INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership + ` + getCurrentRoomStateQuery = ` + SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + FROM current_state cs + JOIN event ON cs.event_rowid = event.rowid + WHERE cs.room_id = $1 + ` + getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3` +) + +type CurrentStateQuery struct { + *dbutil.QueryHelper[*Event] +} + +func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { + return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) +} + +func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) { + return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey) +} + +func (csq *CurrentStateQuery) GetAll(ctx context.Context, roomID id.RoomID) ([]*Event, error) { + return csq.QueryMany(ctx, getCurrentRoomStateQuery, roomID) +} diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go new file mode 100644 index 00000000..e8050e93 --- /dev/null +++ b/hicli/database/statestore.go @@ -0,0 +1,163 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + + "go.mau.fi/util/dbutil" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getMembershipQuery = ` + SELECT membership FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2 + ` + getStateEventContentQuery = ` + SELECT event.content FROM current_state cs + LEFT JOIN event ON event.rowid = cs.event_rowid + WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3 + ` + getRoomJoinedMembersQuery = ` + SELECT state_key FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND membership = 'join' + ` + getRoomJoinedOrInvitedMembersQuery = ` + SELECT state_key FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') + ` + isRoomEncryptedQuery = ` + SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1 + ` + getRoomEncryptionEventQuery = ` + SELECT room.encryption_event FROM room WHERE room_id = $1 + ` + findSharedRoomsQuery = ` + SELECT room_id FROM current_state + WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join' + ` +) + +type ClientStateStore struct { + *Database +} + +var _ mautrix.StateStore = (*ClientStateStore)(nil) +var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil) +var _ crypto.StateStore = (*ClientStateStore)(nil) + +func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipJoin) +} + +func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin) +} + +func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + var membership event.Membership + err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership) + if errors.Is(err, sql.ErrNoRows) { + err = nil + membership = event.MembershipLeave + } + return slices.Contains(allowedMemberships, membership) +} + +func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + content, err := c.TryGetMember(ctx, roomID, userID) + if content == nil { + content = &event.MemberEventContent{Membership: event.MembershipLeave} + } + return content, err +} + +func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { + //TODO implement me + panic("implement me") +} + +func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetRoomJoinedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + rows, err := c.Query(ctx, getRoomJoinedMembersQuery, roomID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) { + err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) { + err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID). + Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) { + // TODO for multiuser support, this might need to filter by the local user's membership + rows, err := c.Query(ctx, findSharedRoomsQuery, userID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() +} + +// Update methods are all intentionally no-ops as the state store wants to have the full event + +func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + return nil +} + +func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + return nil +} + +func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { + return nil +} + +func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {} diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go new file mode 100644 index 00000000..891f6acb --- /dev/null +++ b/hicli/database/timeline.go @@ -0,0 +1,123 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + "sync" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + clearTimelineQuery = ` + DELETE FROM timeline WHERE room_id = $1 + ` + appendTimelineQuery = ` + INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) RETURNING rowid, event_rowid + ` + prependTimelineQuery = ` + INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) + ` + findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` + getTimelineQuery = ` + SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + FROM timeline + JOIN event ON event.rowid = timeline.event_rowid + WHERE timeline.room_id = $1 AND timeline.rowid < $2 + ORDER BY timeline.rowid DESC + LIMIT $3 + ` +) + +type TimelineRowID int64 + +type TimelineRowTuple struct { + Timeline TimelineRowID `json:"timeline_rowid"` + Event EventRowID `json:"event_rowid"` +} + +var timelineRowTupleScanner = dbutil.ConvertRowFn[TimelineRowTuple](func(row dbutil.Scannable) (trt TimelineRowTuple, err error) { + err = row.Scan(&trt.Timeline, &trt.Event) + return +}) + +func (trt TimelineRowTuple) GetMassInsertValues() [2]any { + return [2]any{trt.Timeline, trt.Event} +} + +var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)") +var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelineRowTuple, [1]any](prependTimelineQuery, "($1, $%d, $%d)") + +type TimelineQuery struct { + *dbutil.QueryHelper[*Event] + + minRowID TimelineRowID + minRowIDFound bool + prependLock sync.Mutex +} + +// Clear clears the timeline of a given room. +func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error { + return tq.Exec(ctx, clearTimelineQuery, roomID) +} + +func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) { + tq.prependLock.Lock() + defer tq.prependLock.Unlock() + if !tq.minRowIDFound { + err = tq.GetDB().QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return + } + if tq.minRowID >= 0 { + // No negative row IDs exist, start at -1 + tq.minRowID = -1 + } else { + // We fetched the lowest row ID, but we want the next available one, so decrement one + tq.minRowID-- + } + } + startFrom = tq.minRowID + tq.minRowID -= TimelineRowID(count) + return +} + +// Prepend adds the given event row IDs to the beginning of the timeline. +// The events must be sorted in reverse chronological order (newest event first). +func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) (prependEntries []TimelineRowTuple, err error) { + var startFrom TimelineRowID + startFrom, err = tq.reserveRowIDs(ctx, len(rowIDs)) + if err != nil { + return + } + prependEntries = make([]TimelineRowTuple, len(rowIDs)) + for i, rowID := range rowIDs { + prependEntries[i] = TimelineRowTuple{ + Timeline: startFrom - TimelineRowID(i), + Event: rowID, + } + } + query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries) + err = tq.Exec(ctx, query, params...) + return +} + +// Append adds the given event row IDs to the end of the timeline. +func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) ([]TimelineRowTuple, error) { + query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) + return timelineRowTupleScanner.NewRowIter(tq.GetDB().Query(ctx, query, params...)).AsList() +} + +func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { + return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit) +} diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql new file mode 100644 index 00000000..df6499a1 --- /dev/null +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -0,0 +1,227 @@ +-- v0 -> v1: Latest revision +CREATE TABLE account ( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + access_token TEXT NOT NULL, + homeserver_url TEXT NOT NULL, + + next_batch TEXT NOT NULL +) STRICT; + +CREATE TABLE room ( + room_id TEXT NOT NULL PRIMARY KEY, + creation_content TEXT, + + name TEXT, + name_quality INTEGER NOT NULL DEFAULT 0, + avatar TEXT, + topic TEXT, + canonical_alias TEXT, + lazy_load_summary TEXT, + + encryption_event TEXT, + has_member_list INTEGER NOT NULL DEFAULT false, + + preview_event_rowid INTEGER, + sorting_timestamp INTEGER, + + prev_batch TEXT, + + CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL +) STRICT; +CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); +CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC); + +CREATE TABLE account_data ( + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, type) +) STRICT; + +CREATE TABLE room_account_data ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, room_id, type), + CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX room_account_data_room_id_idx ON room_account_data (room_id); + +CREATE TABLE event ( + rowid INTEGER PRIMARY KEY, + + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + sender TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT, + timestamp INTEGER NOT NULL, + + content TEXT NOT NULL, + decrypted TEXT, + decrypted_type TEXT, + unsigned TEXT NOT NULL, + + transaction_id TEXT, + + redacted_by TEXT, + relates_to TEXT, + relation_type TEXT, + + megolm_session_id TEXT, + decryption_error TEXT, + + reactions TEXT, + last_edit_rowid INTEGER, + + CONSTRAINT event_id_unique_key UNIQUE (event_id), + CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id), + CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX event_room_id_idx ON event (room_id); +CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by); +CREATE INDEX event_relates_to_idx ON event (room_id, relates_to); +CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id); + +CREATE TRIGGER event_update_redacted_by + AFTER INSERT + ON event + WHEN NEW.type = 'm.room.redaction' +BEGIN + UPDATE event SET redacted_by = NEW.event_id WHERE room_id = NEW.room_id AND event_id = NEW.content ->> 'redacts'; +END; + +CREATE TRIGGER event_update_last_edit_when_redacted + AFTER UPDATE + ON event + WHEN OLD.redacted_by IS NULL + AND NEW.redacted_by IS NOT NULL + AND NEW.relation_type = 'm.replace' +BEGIN + UPDATE event + SET last_edit_rowid = COALESCE( + (SELECT rowid + FROM event edit + WHERE edit.room_id = event.room_id + AND edit.relates_to = event.event_id + AND edit.relation_type = 'm.replace' + AND edit.type = event.type + AND edit.sender = event.sender + AND edit.redacted_by IS NULL + ORDER BY edit.timestamp DESC + LIMIT 1), + 0) + WHERE event_id = NEW.relates_to + AND last_edit_rowid = NEW.rowid; +END; + +CREATE TRIGGER event_insert_update_last_edit + AFTER INSERT + ON event + WHEN NEW.relation_type = 'm.replace' + AND NEW.redacted_by IS NULL +BEGIN + UPDATE event + SET last_edit_rowid = NEW.rowid + WHERE event_id = NEW.relates_to + AND type = NEW.type + AND sender = NEW.sender + AND state_key IS NULL + AND NEW.timestamp > + COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); +END; + +CREATE TRIGGER event_insert_fill_reactions + AFTER INSERT + ON event + WHEN NEW.type = 'm.reaction' + AND NEW.relation_type = 'm.annotation' + AND NEW.redacted_by IS NULL + AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' +BEGIN + UPDATE event + SET reactions=json_set( + reactions, + '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), + coalesce( + reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), + 0 + ) + 1) + WHERE event_id = NEW.relates_to + AND reactions IS NOT NULL; +END; + +CREATE TRIGGER event_redact_fill_reactions + AFTER UPDATE + ON event + WHEN NEW.type = 'm.reaction' + AND NEW.relation_type = 'm.annotation' + AND NEW.redacted_by IS NOT NULL + AND OLD.redacted_by IS NULL + AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' +BEGIN + UPDATE event + SET reactions=json_set( + reactions, + '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), + coalesce( + reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), + 0 + ) - 1) + WHERE event_id = NEW.relates_to + AND reactions IS NOT NULL; +END; + +CREATE TABLE session_request ( + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + sender TEXT NOT NULL, + min_index INTEGER NOT NULL, + backup_checked INTEGER NOT NULL DEFAULT false, + request_sent INTEGER NOT NULL DEFAULT false, + + PRIMARY KEY (session_id), + CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX session_request_room_idx ON session_request (room_id); + +CREATE TABLE timeline ( + rowid INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE, + CONSTRAINT timeline_event_unique_key UNIQUE (event_rowid) +) STRICT; +CREATE INDEX timeline_room_id_idx ON timeline (room_id); + +CREATE TABLE current_state ( + room_id TEXT NOT NULL, + event_type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + membership TEXT, + + PRIMARY KEY (room_id, event_type, state_key), + CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) +) STRICT, WITHOUT ROWID; + +CREATE TABLE receipt ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + thread_id TEXT NOT NULL, + event_id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + + PRIMARY KEY (room_id, user_id, receipt_type, thread_id), + CONSTRAINT receipt_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE + -- note: there's no foreign key on event ID because receipts could point at events that are too far in history. +) STRICT; diff --git a/event/cmdschema/testdata/data.go b/hicli/database/upgrades/upgrades.go similarity index 54% rename from event/cmdschema/testdata/data.go rename to hicli/database/upgrades/upgrades.go index eceea3d2..9d0bd1a0 100644 --- a/event/cmdschema/testdata/data.go +++ b/hicli/database/upgrades/upgrades.go @@ -1,14 +1,22 @@ -// Copyright (c) 2026 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package testdata +package upgrades import ( "embed" + + "go.mau.fi/util/dbutil" ) -//go:embed * -var FS embed.FS +var Table dbutil.UpgradeTable + +//go:embed *.sql +var upgrades embed.FS + +func init() { + Table.RegisterFS(upgrades) +} diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go new file mode 100644 index 00000000..02466b69 --- /dev/null +++ b/hicli/decryptionqueue.go @@ -0,0 +1,206 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "sync" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) { + data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID) + if err != nil { + return nil, err + } else if data == nil { + return nil, nil + } + decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey) + if err != nil { + return nil, err + } + return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted) +} + +func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) { + log := zerolog.Ctx(ctx) + err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex) + if err != nil { + log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session") + } + events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID) + if err != nil { + log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption") + return + } else if len(events) == 0 { + log.Trace().Msg("No events to retry decryption for") + return + } + decrypted := events[:0] + for _, evt := range events { + if evt.Decrypted != nil { + continue + } + + evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) + if err != nil { + log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") + } else { + decrypted = append(decrypted, evt) + } + } + if len(decrypted) > 0 { + previewRowIDChanges := make(map[id.RoomID]database.EventRowID) + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + for _, evt := range decrypted { + err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType) + if err != nil { + return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) + } + if evt.CanUseForPreview() { + var previewChanged bool + previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID) + if err != nil { + return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err) + } else if previewChanged { + previewRowIDChanges[evt.RoomID] = evt.RowID + } + } + } + return nil + }) + if err != nil { + log.Err(err).Msg("Failed to save decrypted events") + } else { + h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewRowIDs: previewRowIDChanges}) + } + } +} + +func (h *HiClient) WakeupRequestQueue() { + select { + case h.requestQueueWakeup <- struct{}{}: + default: + } +} + +func (h *HiClient) RunRequestQueue(ctx context.Context) { + log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Starting key request queue") + defer func() { + log.Info().Msg("Stopping key request queue") + }() + for { + err := h.FetchKeysForOutdatedUsers(ctx) + if err != nil { + log.Err(err).Msg("Failed to fetch outdated device lists for tracked users") + } + madeRequests, err := h.RequestQueuedSessions(ctx) + if err != nil { + log.Err(err).Msg("Failed to handle session request queue") + } else if madeRequests { + continue + } + select { + case <-ctx.Done(): + return + case <-h.requestQueueWakeup: + } + } +} + +func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) { + defer doneFunc() + log := zerolog.Ctx(ctx) + if !req.BackupChecked { + sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID) + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to fetch session from key backup") + + // TODO should this have retries instead of just storing it's checked? + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup") + } + } else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex { + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup") + } + } else { + log.Debug().Stringer("session_id", req.SessionID). + Msg("Found session with sufficiently low first known index, removing from queue") + err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex()) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue") + } + } + } else { + err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{ + h.Account.UserID: {"*"}, + req.Sender: {"*"}, + }) + //var err error + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to send key request") + } else { + log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request") + req.RequestSent = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request") + } + } + } +} + +const MaxParallelRequests = 5 + +func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) { + sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests) + if err != nil { + return false, fmt.Errorf("failed to get next events to decrypt: %w", err) + } else if len(sessions) == 0 { + return false, nil + } + var wg sync.WaitGroup + wg.Add(len(sessions)) + for _, req := range sessions { + go h.requestQueuedSession(ctx, req, wg.Done) + } + wg.Wait() + + return true, err +} + +func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error { + outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx) + if err != nil { + return err + } else if len(outdatedUsers) == 0 { + return nil + } + _, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false) + if err != nil { + return err + } + // TODO backoff for users that fail to be fetched? + return nil +} diff --git a/hicli/events.go b/hicli/events.go new file mode 100644 index 00000000..a30dda8d --- /dev/null +++ b/hicli/events.go @@ -0,0 +1,39 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type SyncRoom struct { + Meta *database.Room `json:"meta"` + Timeline []database.TimelineRowTuple `json:"timeline"` + Events []*database.Event `json:"events"` + Reset bool `json:"reset"` +} + +type SyncComplete struct { + Rooms map[id.RoomID]*SyncRoom `json:"rooms"` +} + +type EventsDecrypted struct { + PreviewRowIDs map[id.RoomID]database.EventRowID `json:"room_preview_rowids"` + Events []*database.Event `json:"events"` +} + +type Typing struct { + RoomID id.RoomID `json:"room_id"` + event.TypingEventContent +} + +type SendComplete struct { + Event *database.Event `json:"event"` + Error error `json:"error"` +} diff --git a/hicli/hicli.go b/hicli/hicli.go new file mode 100644 index 00000000..7524b6bc --- /dev/null +++ b/hicli/hicli.go @@ -0,0 +1,224 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix. +package hicli + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type HiClient struct { + DB *database.Database + Account *database.Account + Client *mautrix.Client + Crypto *crypto.OlmMachine + CryptoStore *crypto.SQLCryptoStore + ClientStore *database.ClientStateStore + Log zerolog.Logger + + Verified bool + + KeyBackupVersion id.KeyBackupVersion + KeyBackupKey *backup.MegolmBackupKey + + EventHandler func(evt any) + + firstSyncReceived bool + syncingID int + syncLock sync.Mutex + stopSync context.CancelFunc + encryptLock sync.Mutex + + requestQueueWakeup chan struct{} + + paginationInterrupterLock sync.Mutex + paginationInterrupter map[id.RoomID]context.CancelCauseFunc +} + +var ErrTimelineReset = errors.New("got limited timeline sync response") + +func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient { + if cryptoDB == nil { + cryptoDB = rawDB + } + if rawDB.Owner == "" { + rawDB.Owner = "hicli" + rawDB.IgnoreForeignTables = true + } + if rawDB.Log == nil { + rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger()) + } + db := database.New(rawDB) + c := &HiClient{ + DB: db, + Log: log, + + requestQueueWakeup: make(chan struct{}, 1), + + EventHandler: evtHandler, + } + c.ClientStore = &database.ClientStateStore{Database: db} + c.Client = &mautrix.Client{ + UserAgent: mautrix.DefaultUserAgent, + Client: &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + // This needs to be relatively high to allow initial syncs + ResponseHeaderTimeout: 180 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 180 * time.Second, + }, + Syncer: (*hiSyncer)(c), + Store: (*hiStore)(c), + StateStore: c.ClientStore, + Log: log.With().Str("component", "mautrix client").Logger(), + } + c.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey) + cryptoLog := log.With().Str("component", "crypto").Logger() + c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore) + c.Crypto.SessionReceived = c.handleReceivedMegolmSession + c.Crypto.DisableRatchetTracking = true + c.Crypto.DisableDecryptKeyFetching = true + c.Client.Crypto = (*hiCryptoHelper)(c) + return c +} + +func (h *HiClient) IsLoggedIn() bool { + return h.Account != nil +} + +func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error { + if expectedAccount != nil && userID != expectedAccount.UserID { + panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID")) + } + err := h.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade hicli db: %w", err) + } + err = h.CryptoStore.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade crypto db: %w", err) + } + account, err := h.DB.Account.Get(ctx, userID) + if err != nil { + return err + } else if account == nil && expectedAccount != nil { + err = h.DB.Account.Put(ctx, expectedAccount) + if err != nil { + return err + } + account = expectedAccount + } else if expectedAccount != nil && expectedAccount.DeviceID != account.DeviceID { + return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, account.DeviceID) + } + if account != nil { + zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials") + h.Account = account + h.CryptoStore.AccountID = account.UserID.String() + h.CryptoStore.DeviceID = account.DeviceID + h.Client.UserID = account.UserID + h.Client.DeviceID = account.DeviceID + h.Client.AccessToken = account.AccessToken + h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL) + if err != nil { + return err + } + err = h.CheckServerVersions(ctx) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + + h.Verified, err = h.checkIsCurrentDeviceVerified(ctx) + if err != nil { + return err + } + zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status") + if h.Verified { + err = h.loadPrivateKeys(ctx) + if err != nil { + return err + } + go h.Sync() + go h.RunRequestQueue(ctx) + } + } + return nil +} + +var ErrFailedToCheckServerVersions = errors.New("failed to check server versions") +var ErrOutdatedServer = errors.New("homeserver is outdated") +var MinimumSpecVersion = mautrix.SpecV11 + +func (h *HiClient) CheckServerVersions(ctx context.Context) error { + versions, err := h.Client.Versions(ctx) + if err != nil { + return exerrors.NewDualError(ErrFailedToCheckServerVersions, err) + } else if !versions.Contains(MinimumSpecVersion) { + return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, versions.GetLatest()) + } + return nil +} + +func (h *HiClient) Sync() { + h.Client.StopSync() + if fn := h.stopSync; fn != nil { + fn() + } + h.syncLock.Lock() + defer h.syncLock.Unlock() + h.syncingID++ + syncingID := h.syncingID + log := h.Log.With(). + Str("action", "sync"). + Int("sync_id", syncingID). + Logger() + ctx, cancel := context.WithCancel(log.WithContext(context.Background())) + h.stopSync = cancel + log.Info().Msg("Starting syncing") + err := h.Client.SyncWithContext(ctx) + if err != nil && ctx.Err() == nil { + log.Err(err).Msg("Fatal error in syncer") + } else { + log.Info().Msg("Syncing stopped") + } +} + +func (h *HiClient) Stop() { + h.Client.StopSync() + if fn := h.stopSync; fn != nil { + fn() + } + h.syncLock.Lock() + h.syncLock.Unlock() + err := h.DB.Close() + if err != nil { + h.Log.Err(err).Msg("Failed to close database cleanly") + } +} diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go new file mode 100644 index 00000000..c6873bac --- /dev/null +++ b/hicli/hitest/hitest.go @@ -0,0 +1,110 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package main + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/chzyer/readline" + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exzerolog" + "go.mau.fi/zeroconfig" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli" + "maunium.net/go/mautrix/id" +) + +var writerTypeReadline zeroconfig.WriterType = "hitest_readline" + +func main() { + hicli.InitialDeviceDisplayName = "mautrix hitest" + rl := exerrors.Must(readline.New("> ")) + defer func() { + _ = rl.Close() + }() + zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) { + return rl.Stdout(), nil + }) + debug := zerolog.DebugLevel + log := exerrors.Must((&zeroconfig.Config{ + MinLevel: &debug, + Writers: []zeroconfig.WriterConfig{{ + Type: writerTypeReadline, + Format: zeroconfig.LogFormatPrettyColored, + }}, + }).Compile()) + exzerolog.SetupDefaults(log) + + rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal")) + ctx := log.WithContext(context.Background()) + cli := hicli.New(rawDB, nil, *log, []byte("meow"), func(a any) { + _, _ = fmt.Fprintf(rl, "Received event of type %T\n", a) + switch evt := a.(type) { + case *hicli.SyncComplete: + for _, room := range evt.Rooms { + name := "name unset" + if room.Meta.Name != nil { + name = *room.Meta.Name + } + _, _ = fmt.Fprintf(rl, "Room %s (%s) in sync:\n", name, room.Meta.ID) + _, _ = fmt.Fprintf(rl, " Preview: %d, sort: %v\n", room.Meta.PreviewEventRowID, room.Meta.SortingTimestamp) + _, _ = fmt.Fprintf(rl, " Timeline: +%d %v, reset: %t\n", len(room.Timeline), room.Timeline, room.Reset) + } + case *hicli.EventsDecrypted: + for _, decrypted := range evt.Events { + _, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted) + } + if len(evt.PreviewRowIDs) > 0 { + _, _ = fmt.Fprintf(rl, "Room previews updated: %+v\n", evt.PreviewRowIDs) + } + case *hicli.Typing: + _, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs) + } + }) + userID, _ := cli.DB.Account.GetFirstUserID(ctx) + exerrors.PanicIfNotNil(cli.Start(ctx, userID, nil)) + if !cli.IsLoggedIn() { + rl.SetPrompt("User ID: ") + userID := id.UserID(exerrors.Must(rl.Readline())) + _, serverName := exerrors.Must2(userID.Parse()) + discovery := exerrors.Must(mautrix.DiscoverClientAPI(ctx, serverName)) + password := exerrors.Must(rl.ReadPassword("Password: ")) + recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: ")) + exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode))) + } + rl.SetPrompt("> ") + + for { + line, err := rl.Readline() + if err != nil { + break + } + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + switch strings.ToLower(fields[0]) { + case "send": + resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{ + Body: strings.Join(fields[2:], " "), + MsgType: event.MsgText, + }) + _, _ = fmt.Fprintln(rl, err) + _, _ = fmt.Fprintf(rl, "%+v\n", resp) + } + } + cli.Stop() +} diff --git a/hicli/login.go b/hicli/login.go new file mode 100644 index 00000000..d33ea422 --- /dev/null +++ b/hicli/login.go @@ -0,0 +1,88 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "net/url" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +var InitialDeviceDisplayName = "mautrix hiclient" + +func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error { + var err error + h.Client.HomeserverURL, err = url.Parse(homeserverURL) + if err != nil { + return err + } + return h.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: username, + }, + Password: password, + InitialDeviceDisplayName: InitialDeviceDisplayName, + }) +} + +func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { + err := h.CheckServerVersions(ctx) + if err != nil { + return err + } + req.StoreCredentials = true + req.StoreHomeserverURL = true + resp, err := h.Client.Login(ctx, req) + if err != nil { + return err + } + h.Account = &database.Account{ + UserID: resp.UserID, + DeviceID: resp.DeviceID, + AccessToken: resp.AccessToken, + HomeserverURL: h.Client.HomeserverURL.String(), + } + h.CryptoStore.AccountID = resp.UserID.String() + h.CryptoStore.DeviceID = resp.DeviceID + err = h.DB.Account.Put(ctx, h.Account) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + err = h.Crypto.ShareKeys(ctx, 0) + if err != nil { + return err + } + _, err = h.Crypto.FetchKeys(ctx, []id.UserID{h.Account.UserID}, true) + if err != nil { + return fmt.Errorf("failed to fetch own devices: %w", err) + } + return nil +} + +func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryCode string) error { + err := h.LoginPassword(ctx, homeserverURL, username, password) + if err != nil { + return err + } + err = h.VerifyWithRecoveryCode(ctx, recoveryCode) + if err != nil { + return err + } + go h.Sync() + go h.RunRequestQueue(ctx) + return nil +} diff --git a/hicli/paginate.go b/hicli/paginate.go new file mode 100644 index 00000000..9992b36e --- /dev/null +++ b/hicli/paginate.go @@ -0,0 +1,141 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress") + +func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) { + events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...) + if err != nil { + return nil, err + } else if len(events) == 0 { + return events, nil + } + firstRoomID := events[0].RoomID + allInSameRoom := true + for _, evt := range events { + if evt.RoomID != firstRoomID { + allInSameRoom = false + break + } + } + if allInSameRoom { + err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events) + if err != nil { + return events, fmt.Errorf("failed to fill last edit row IDs: %w", err) + } + err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events) + if err != nil { + return events, fmt.Errorf("failed to fill reaction counts: %w", err) + } + } else { + // TODO slow path where events are collected and filling is done one room at a time? + } + return events, nil +} + +func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) { + if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil { + return nil, fmt.Errorf("failed to get event from database: %w", err) + } else if evt != nil { + return evt, nil + } else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil { + return nil, fmt.Errorf("failed to get event from server: %w", err) + } else { + return h.processEvent(ctx, serverEvt, nil, false) + } +} + +func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) ([]*database.Event, error) { + evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID) + if err != nil { + return nil, err + } else if len(evts) > 0 { + return evts, nil + } else { + return h.PaginateServer(ctx, roomID, limit) + } +} + +func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) ([]*database.Event, error) { + ctx, cancel := context.WithCancelCause(ctx) + h.paginationInterrupterLock.Lock() + if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating { + h.paginationInterrupterLock.Unlock() + return nil, ErrPaginationAlreadyInProgress + } + h.paginationInterrupter[roomID] = cancel + h.paginationInterrupterLock.Unlock() + defer func() { + h.paginationInterrupterLock.Lock() + delete(h.paginationInterrupter, roomID) + h.paginationInterrupterLock.Unlock() + }() + + room, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room from database: %w", err) + } + resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit) + if err != nil { + return nil, fmt.Errorf("failed to get messages from server: %w", err) + } + events := make([]*database.Event, len(resp.Chunk)) + wakeupSessionRequests := false + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + if err = ctx.Err(); err != nil { + return err + } + eventRowIDs := make([]database.EventRowID, len(resp.Chunk)) + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + for i, evt := range resp.Chunk { + events[i], err = h.processEvent(ctx, evt, decryptionQueue, true) + if err != nil { + return err + } + eventRowIDs[i] = events[i].RowID + } + wakeupSessionRequests = len(decryptionQueue) > 0 + for _, entry := range decryptionQueue { + err = h.DB.SessionRequest.Put(ctx, entry) + if err != nil { + return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) + } + } + err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events) + if err != nil { + return fmt.Errorf("failed to fill last edit row IDs: %w", err) + } + err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) + if err != nil { + return fmt.Errorf("failed to set prev_batch: %w", err) + } + var tuples []database.TimelineRowTuple + tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs) + if err != nil { + return fmt.Errorf("failed to prepend events to timeline: %w", err) + } + for i, evt := range events { + evt.TimelineRowID = tuples[i].Timeline + } + return nil + }) + if err == nil && wakeupSessionRequests { + h.WakeupRequestQueue() + } + return events, err +} diff --git a/hicli/send.go b/hicli/send.go new file mode 100644 index 00000000..66175e75 --- /dev/null +++ b/hicli/send.go @@ -0,0 +1,207 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) { + roomMeta, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return nil, fmt.Errorf("unknown room") + } + var decryptedType event.Type + var decryptedContent json.RawMessage + var megolmSessionID id.SessionID + if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction { + decryptedType = evtType + decryptedContent, err = json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content) + if err != nil { + return nil, fmt.Errorf("failed to encrypt event: %w", err) + } + megolmSessionID = encryptedContent.SessionID + content = encryptedContent + evtType = event.EventEncrypted + } + mainContent, err := json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + var zero database.EventRowID + txnID := "hicli-" + h.Client.TxnID() + relatesTo, relationType := database.GetRelatesToFromBytes(mainContent) + dbEvt := &database.Event{ + RoomID: roomID, + ID: id.EventID(fmt.Sprintf("~%s", txnID)), + Sender: h.Account.UserID, + Type: evtType.Type, + Timestamp: time.Now(), + Content: mainContent, + Decrypted: decryptedContent, + DecryptedType: decryptedType.Type, + Unsigned: []byte("{}"), + TransactionID: txnID, + RelatesTo: relatesTo, + RelationType: relationType, + MegolmSessionID: megolmSessionID, + DecryptionError: "", + Reactions: map[string]int{}, + LastEditRowID: &zero, + } + _, err = h.DB.Event.Insert(ctx, dbEvt) + if err != nil { + return nil, fmt.Errorf("failed to insert event into database: %w", err) + } + go func() { + var err error + defer func() { + h.EventHandler(&SendComplete{ + Event: dbEvt, + Error: err, + }) + }() + var resp *mautrix.RespSendEvent + resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{ + Timestamp: dbEvt.Timestamp.UnixMilli(), + TransactionID: txnID, + DontEncrypt: true, + }) + if err != nil { + // TODO save send error to db? + err = fmt.Errorf("failed to send event: %w", err) + return + } + dbEvt.ID = resp.EventID + err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID) + if err != nil { + err = fmt.Errorf("failed to update event ID in database: %w", err) + } + }() + return dbEvt, nil +} + +func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content) + if errors.Is(err, crypto.SessionExpired) || errors.Is(err, crypto.NoGroupSession) || errors.Is(err, crypto.SessionNotShared) { + if err = h.shareGroupSession(ctx, room); err != nil { + err = fmt.Errorf("failed to share group session: %w", err) + } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content); err != nil { + err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + return +} + +func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + if session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID); err != nil { + return fmt.Errorf("failed to get previous outbound group session: %w", err) + } else if session != nil && session.Shared && !session.Expired() { + return nil + } else if roomMeta, err := h.DB.Room.Get(ctx, roomID); err != nil { + return fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return fmt.Errorf("unknown room") + } else { + return h.shareGroupSession(ctx, roomMeta) + } +} + +func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { + if room.HasMemberList { + return nil + } + resp, err := h.Client.Members(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to get room member list: %w", err) + } + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + for _, evt := range resp.Chunk { + dbEvt, err := h.processEvent(ctx, evt, nil, true) + if err != nil { + return err + } + membership := event.Membership(evt.Content.Raw["membership"].(string)) + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) + if err != nil { + return err + } + } + return h.DB.Room.Upsert(ctx, &database.Room{ + ID: room.ID, + HasMemberList: true, + }) + }) + if err != nil { + return fmt.Errorf("failed to process room member list: %w", err) + } + return nil +} + +func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error { + err := h.loadMembers(ctx, room) + if err != nil { + return err + } + shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID) + var users []id.UserID + if shareToInvited { + users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID) + } else { + users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID) + } + if err != nil { + return fmt.Errorf("failed to get room member list: %w", err) + } else if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil { + return fmt.Errorf("failed to share group session: %w", err) + } + return nil +} + +func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool { + historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event") + return false + } + mautrixEvt := historyVisibility.AsRawMautrix() + err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event") + return false + } + hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent) + if !ok { + zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event") + return false + } + return hv.HistoryVisibility == event.HistoryVisibilityInvited || + hv.HistoryVisibility == event.HistoryVisibilityShared || + hv.HistoryVisibility == event.HistoryVisibilityWorldReadable +} diff --git a/hicli/sync.go b/hicli/sync.go new file mode 100644 index 00000000..c3f30a72 --- /dev/null +++ b/hicli/sync.go @@ -0,0 +1,556 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exzerolog" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type syncContext struct { + shouldWakeupRequestQueue bool + + evt *SyncComplete +} + +func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + log := zerolog.Ctx(ctx) + postponedToDevices := resp.ToDevice.Events[:0] + for _, evt := range resp.ToDevice.Events { + evt.Type.Class = event.ToDeviceEventType + err := evt.Content.ParseRaw(evt.Type) + if err != nil { + log.Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("sender", evt.Sender). + Msg("Failed to parse to-device event, skipping") + continue + } + + switch content := evt.Content.Parsed.(type) { + case *event.EncryptedEventContent: + h.Crypto.HandleEncryptedEvent(ctx, evt) + case *event.RoomKeyWithheldEventContent: + h.Crypto.HandleRoomKeyWithheld(ctx, content) + default: + postponedToDevices = append(postponedToDevices, evt) + } + } + resp.ToDevice.Events = postponedToDevices + + return nil +} + +func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { + h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) + go h.asyncPostProcessSyncResponse(ctx, resp, since) + syncCtx := ctx.Value(syncContextKey).(*syncContext) + if syncCtx.shouldWakeupRequestQueue { + h.WakeupRequestQueue() + } + h.firstSyncReceived = true + h.EventHandler(syncCtx.evt) +} + +func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { + for _, evt := range resp.ToDevice.Events { + switch content := evt.Content.Parsed.(type) { + case *event.SecretRequestEventContent: + h.Crypto.HandleSecretRequest(ctx, evt.Sender, content) + case *event.RoomKeyRequestEventContent: + h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content) + } + } +} + +func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + if len(resp.DeviceLists.Changed) > 0 { + zerolog.Ctx(ctx).Debug(). + Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)). + Msg("Marking changed device lists for tracked users as outdated") + err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed) + if err != nil { + return fmt.Errorf("failed to mark changed device lists as outdated: %w", err) + } + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + + for _, evt := range resp.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + } + for roomID, room := range resp.Rooms.Join { + err := h.processSyncJoinedRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process joined room %s: %w", roomID, err) + } + } + for roomID, room := range resp.Rooms.Leave { + err := h.processSyncLeftRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process left room %s: %w", roomID, err) + } + } + h.Account.NextBatch = resp.NextBatch + err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch) + if err != nil { + return fmt.Errorf("failed to save next_batch: %w", err) + } + return nil +} + +func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt { + receiptList := make([]*database.Receipt, 0) + for eventID, receipts := range *content { + for receiptType, users := range receipts { + for userID, receiptInfo := range users { + receiptList = append(receiptList, &database.Receipt{ + UserID: userID, + ReceiptType: receiptType, + ThreadID: receiptInfo.ThreadID, + EventID: eventID, + Timestamp: receiptInfo.Timestamp, + }) + } + } + } + return receiptList +} + +func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + err = h.DB.Room.CreateRow(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to ensure room row exists: %w", err) + } + existingRoomData = &database.Room{ID: roomID} + } + + for _, evt := range room.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + evt.RoomID = roomID + err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + } + err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) + if err != nil { + return err + } + for _, evt := range room.Ephemeral.Events { + evt.Type.Class = event.EphemeralEventType + err = evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content") + continue + } + switch evt.Type { + case event.EphemeralEventReceipt: + err = h.DB.Receipt.PutMany(ctx, roomID, receiptsToList(evt.Content.AsReceipt())...) + if err != nil { + return fmt.Errorf("failed to save receipts: %w", err) + } + case event.EphemeralEventTyping: + go h.EventHandler(&Typing{ + RoomID: roomID, + TypingEventContent: *evt.Content.AsTyping(), + }) + } + if evt.Type != event.EphemeralEventReceipt { + continue + } + } + return nil +} + +func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + return nil + } + return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) +} + +func isDecryptionErrorRetryable(err error) bool { + return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld) +} + +func removeReplyFallback(evt *event.Event) []byte { + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if ok && content.RelatesTo.GetReplyTo() != "" { + prevFormattedBody := content.FormattedBody + content.RemoveReplyFallback() + if content.FormattedBody != prevFormattedBody { + bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody) + bytes, err2 := sjson.SetBytes(bytes, "body", content.Body) + if err == nil && err2 == nil { + return bytes + } + } + } + return nil +} + +func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) { + err := evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + return nil, "", err + } + decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) + if err != nil { + return nil, "", err + } + withoutFallback := removeReplyFallback(decrypted) + if withoutFallback != nil { + return withoutFallback, decrypted.Type.Type, nil + } + return decrypted.Content.VeryRaw, decrypted.Type.Type, nil +} + +func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) { + if checkDB { + dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID) + if err != nil { + return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err) + } else if dbEvt != nil { + return dbEvt, nil + } + } + dbEvt := database.MautrixToEvent(evt) + contentWithoutFallback := removeReplyFallback(evt) + if contentWithoutFallback != nil { + dbEvt.Content = contentWithoutFallback + } + var decryptionErr error + if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { + dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) + if decryptionErr != nil { + dbEvt.DecryptionError = decryptionErr.Error() + } + } else if evt.Type == event.EventRedaction { + if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() { + var err error + evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts) + if err != nil { + return dbEvt, fmt.Errorf("failed to set redacts field: %w", err) + } + } + } + _, err := h.DB.Event.Upsert(ctx, dbEvt) + if err != nil { + return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) + } + if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { + req, ok := decryptionQueue[dbEvt.MegolmSessionID] + if !ok { + req = &database.SessionRequest{ + RoomID: evt.RoomID, + SessionID: dbEvt.MegolmSessionID, + Sender: evt.Sender, + } + } + minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) + req.MinIndex = min(uint32(minIndex), req.MinIndex) + decryptionQueue[dbEvt.MegolmSessionID] = req + } + return dbEvt, err +} + +func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error { + updatedRoom := &database.Room{ + ID: room.ID, + + SortingTimestamp: room.SortingTimestamp, + NameQuality: room.NameQuality, + } + heroesChanged := false + if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil { + summary = room.LazyLoadSummary + } else if room.LazyLoadSummary == nil || + !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || + !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || + !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) { + updatedRoom.LazyLoadSummary = summary + heroesChanged = true + } + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) + processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) { + evt.RoomID = room.ID + dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false) + if err != nil { + return -1, err + } + if isTimeline { + if dbEvt.CanUseForPreview() { + updatedRoom.PreviewEventRowID = dbEvt.RowID + } + updatedRoom.BumpSortingTimestamp(dbEvt) + } + if evt.StateKey != nil { + var membership event.Membership + if evt.Type == event.StateMember { + membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str) + if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) { + heroesChanged = true + } + } else if evt.Type == event.StateElementFunctionalMembers { + heroesChanged = true + } + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) + if err != nil { + return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err) + } + processImportantEvent(ctx, evt, room, updatedRoom) + } + allNewEvents = append(allNewEvents, dbEvt) + return dbEvt.RowID, nil + } + var err error + for _, evt := range state.Events { + evt.Type.Class = event.StateEventType + _, err = processNewEvent(evt, false) + if err != nil { + return err + } + } + var timelineRowTuples []database.TimelineRowTuple + if len(timeline.Events) > 0 { + timelineIDs := make([]database.EventRowID, len(timeline.Events)) + for i, evt := range timeline.Events { + if evt.StateKey != nil { + evt.Type.Class = event.StateEventType + } else { + evt.Type.Class = event.MessageEventType + } + timelineIDs[i], err = processNewEvent(evt, true) + if err != nil { + return err + } + } + for _, entry := range decryptionQueue { + err = h.DB.SessionRequest.Put(ctx, entry) + if err != nil { + return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) + } + } + if len(decryptionQueue) > 0 { + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + if timeline.Limited { + err = h.DB.Timeline.Clear(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to clear old timeline: %w", err) + } + updatedRoom.PrevBatch = timeline.PrevBatch + h.paginationInterrupterLock.Lock() + if interrupt, ok := h.paginationInterrupter[room.ID]; ok { + interrupt(ErrTimelineReset) + } + h.paginationInterrupterLock.Unlock() + } + timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) + if err != nil { + return fmt.Errorf("failed to append timeline: %w", err) + } + } + // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset + if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil { + name, err := h.calculateRoomParticipantName(ctx, room.ID, summary) + if err != nil { + return fmt.Errorf("failed to calculate room name: %w", err) + } + updatedRoom.Name = &name + updatedRoom.NameQuality = database.NameQualityParticipants + } + if timeline.PrevBatch != "" && room.PrevBatch == "" { + updatedRoom.PrevBatch = timeline.PrevBatch + } + roomChanged := updatedRoom.CheckChangesAndCopyInto(room) + if roomChanged { + err = h.DB.Room.Upsert(ctx, updatedRoom) + if err != nil { + return fmt.Errorf("failed to save room data: %w", err) + } + } + if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 { + ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ + Meta: room, + Timeline: timelineRowTuples, + Reset: timeline.Limited, + Events: allNewEvents, + } + } + return nil +} + +func joinMemberNames(names []string, totalCount int) string { + if len(names) == 1 { + return names[0] + } else if len(names) < 5 || (len(names) == 5 && totalCount <= 6) { + return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1] + } else { + return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5) + } +} + +func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, error) { + if summary == nil || len(summary.Heroes) == 0 { + return "Empty room", nil + } + var functionalMembers []id.UserID + functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "") + if err != nil { + return "", fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err) + } else if functionalMembersEvt != nil { + mautrixEvt := functionalMembersEvt.AsRawMautrix() + _ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) + content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent) + if ok { + functionalMembers = content.ServiceMembers + } + } + var members, leftMembers []string + var memberCount int + if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 { + memberCount = *summary.JoinedMemberCount + } else if summary.InvitedMemberCount != nil { + memberCount = *summary.InvitedMemberCount + } + for _, hero := range summary.Heroes { + if slices.Contains(functionalMembers, hero) { + memberCount-- + continue + } else if len(members) >= 5 { + break + } + heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String()) + if err != nil { + return "", fmt.Errorf("failed to get %s's member event: %w", hero, err) + } + results := gjson.GetManyBytes(heroEvt.Content, "membership", "displayname") + name := results[1].Str + if name == "" { + name = hero.String() + } + if results[0].Str == "join" || results[0].Str == "invite" { + members = append(members, name) + } else { + leftMembers = append(leftMembers, name) + } + } + if len(members) > 0 { + return joinMemberNames(members, memberCount), nil + } else if len(leftMembers) > 0 { + return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), nil + } else { + return "Empty room", nil + } +} + +func intPtrEqual(a, b *int) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + +func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) { + if evt.StateKey == nil { + return + } + switch evt.Type { + case event.StateCreate, event.StateRoomName, event.StateCanonicalAlias, event.StateRoomAvatar, event.StateTopic, event.StateEncryption: + if *evt.StateKey != "" { + return + } + default: + return + } + err := evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("event_id", evt.ID). + Msg("Failed to parse state event, skipping") + return + } + switch evt.Type { + case event.StateCreate: + updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent) + case event.StateEncryption: + newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent) + if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm { + updatedRoom.EncryptionEvent = newEncryption + } + case event.StateRoomName: + content, ok := evt.Content.Parsed.(*event.RoomNameEventContent) + if ok { + updatedRoom.Name = &content.Name + updatedRoom.NameQuality = database.NameQualityExplicit + if content.Name == "" { + if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" { + updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + } else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" { + updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + } else { + updatedRoom.NameQuality = database.NameQualityNil + } + } + } + case event.StateCanonicalAlias: + content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent) + if ok { + updatedRoom.CanonicalAlias = &content.Alias + if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias { + updatedRoom.Name = (*string)(&content.Alias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + if content.Alias == "" { + updatedRoom.NameQuality = database.NameQualityNil + } + } + } + case event.StateRoomAvatar: + content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) + if ok { + url, _ := content.URL.Parse() + updatedRoom.Avatar = &url + } + case event.StateTopic: + content, ok := evt.Content.Parsed.(*event.TopicEventContent) + if ok { + updatedRoom.Topic = &content.Topic + } + } + return +} diff --git a/hicli/syncwrap.go b/hicli/syncwrap.go new file mode 100644 index 00000000..13837202 --- /dev/null +++ b/hicli/syncwrap.go @@ -0,0 +1,96 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "time" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type hiSyncer HiClient + +var _ mautrix.Syncer = (*hiSyncer)(nil) + +type contextKey int + +const ( + syncContextKey contextKey = iota +) + +func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + c := (*HiClient)(h) + ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join))}}) + err := c.preProcessSyncResponse(ctx, resp, since) + if err != nil { + return err + } + err = c.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + return c.processSyncResponse(ctx, resp, since) + }) + if err != nil { + return err + } + c.postProcessSyncResponse(ctx, resp, since) + return nil +} + +func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { + (*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second") + return 1 * time.Second, nil +} + +func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { + if !h.Verified { + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + NotRooms: []id.RoomID{"*"}, + }, + } + } + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + State: mautrix.FilterPart{ + LazyLoadMembers: true, + }, + Timeline: mautrix.FilterPart{ + Limit: 100, + LazyLoadMembers: true, + }, + }, + } +} + +type hiStore HiClient + +var _ mautrix.SyncStore = (*hiStore)(nil) + +// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing + +func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil } +func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil } + +func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error { + // This is intentionally a no-op: we don't want to save the next batch before processing the sync + return nil +} + +func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) { + if h.Account.UserID != userID { + return "", fmt.Errorf("mismatching user ID") + } + return h.Account.NextBatch, nil +} diff --git a/hicli/verify.go b/hicli/verify.go new file mode 100644 index 00000000..905be052 --- /dev/null +++ b/hicli/verify.go @@ -0,0 +1,158 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/base64" + "fmt" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/crypto/ssss" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) { + keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx) + if keys == nil { + return false, fmt.Errorf("own cross-signing keys not found") + } + isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey) + if err != nil { + return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) + } + return isVerified, nil +} + +func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error { + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey) + if err != nil { + return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err) + } + key, err := backup.MegolmBackupKeyFromBytes(data) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes())) + if err != nil { + return fmt.Errorf("failed to store megolm backup key: %w", err) + } + h.KeyBackupKey = key + return nil +} + +func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) { + secretData, err := h.CryptoStore.GetSecret(ctx, secret) + if err != nil { + return nil, fmt.Errorf("failed to get secret %s: %w", secret, err) + } + data, err := base64.StdEncoding.DecodeString(secretData) + if err != nil { + return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err) + } + return data, nil +} + +func (h *HiClient) loadPrivateKeys(ctx context.Context) error { + zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys") + masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster) + if err != nil { + return fmt.Errorf("failed to get master key: %w", err) + } + selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning) + if err != nil { + return fmt.Errorf("failed to get self-signing key: %w", err) + } + userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning) + if err != nil { + return fmt.Errorf("failed to get user signing key: %w", err) + } + err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{ + MasterKey: masterKeySeed, + SelfSigningKey: selfSigningKeySeed, + UserSigningKey: userSigningKeySeed, + }) + if err != nil { + return fmt.Errorf("failed to import cross-signing private keys: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Loading key backup key") + keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1) + if err != nil { + return fmt.Errorf("failed to get megolm backup key: %w", err) + } + h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version") + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + zerolog.Ctx(ctx).Debug().Msg("Secrets loaded") + return nil +} + +func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error { + keys := h.Crypto.CrossSigningKeys + err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed())) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error { + keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) + if err != nil { + return fmt.Errorf("failed to get default SSSS key data: %w", err) + } + key, err := keyData.VerifyRecoveryKey(keyID, code) + if err != nil { + return err + } + err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) + } + err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity()) + if err != nil { + return fmt.Errorf("failed to sign own device: %w", err) + } + err = h.Crypto.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } + err = h.storeCrossSigningPrivateKeys(ctx) + if err != nil { + return fmt.Errorf("failed to store cross-signing private keys: %w", err) + } + err = h.fetchKeyBackupKey(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch key backup key: %w", err) + } + h.Verified = true + return nil +} diff --git a/id/contenturi.go b/id/contenturi.go index 67127b6c..e6a313f5 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -17,14 +17,8 @@ import ( ) var ( - ErrInvalidContentURI = errors.New("invalid Matrix content URI") - ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") -) - -// Deprecated: use variables prefixed with Err -var ( - InvalidContentURI = ErrInvalidContentURI - InputNotJSONString = ErrInputNotJSONString + InvalidContentURI = errors.New("invalid Matrix content URI") + InputNotJSONString = errors.New("input doesn't look like a JSON string") ) // ContentURIString is a string that's expected to be a Matrix content URI. @@ -61,9 +55,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { - err = ErrInvalidContentURI + err = InvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = ErrInvalidContentURI + err = InvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] @@ -77,9 +71,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { - err = ErrInvalidContentURI + err = InvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = ErrInvalidContentURI + err = InvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) @@ -92,7 +86,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString) + return InputNotJSONString } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { diff --git a/id/crypto.go b/id/crypto.go index ee857f78..355a84a8 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -53,34 +53,6 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) -type KeySource string - -func (source KeySource) String() string { - return string(source) -} - -func (source KeySource) Int() int { - switch source { - case KeySourceDirect: - return 100 - case KeySourceBackup: - return 90 - case KeySourceImport: - return 80 - case KeySourceForward: - return 50 - default: - return 0 - } -} - -const ( - KeySourceDirect KeySource = "direct" - KeySourceBackup KeySource = "backup" - KeySourceImport KeySource = "import" - KeySourceForward KeySource = "forward" -) - // BackupVersion is an arbitrary string that identifies a server side key backup. type KeyBackupVersion string diff --git a/id/matrixuri.go b/id/matrixuri.go index d5c78bc7..5ec403e9 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{ func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) - if len(uri.Via) > 0 { + if uri.Via != nil && len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { @@ -65,9 +65,6 @@ func (uri *MatrixURI) getQuery() url.Values { // String converts the parsed matrix: URI back into the string representation. func (uri *MatrixURI) String() string { - if uri == nil { - return "" - } parts := []string{ SigilToPathSegment[uri.Sigil1], url.PathEscape(uri.MXID1), @@ -84,9 +81,6 @@ func (uri *MatrixURI) String() string { // MatrixToURL converts to parsed matrix: URI into a matrix.to URL func (uri *MatrixURI) MatrixToURL() string { - if uri == nil { - return "" - } fragment := fmt.Sprintf("#/%s", url.PathEscape(uri.PrimaryIdentifier())) if uri.Sigil2 != 0 { fragment = fmt.Sprintf("%s/%s", fragment, url.PathEscape(uri.SecondaryIdentifier())) @@ -102,16 +96,13 @@ func (uri *MatrixURI) MatrixToURL() string { // PrimaryIdentifier returns the first Matrix identifier in the URI. // Currently room IDs, room aliases and user IDs can be in the primary identifier slot. func (uri *MatrixURI) PrimaryIdentifier() string { - if uri == nil { - return "" - } return fmt.Sprintf("%c%s", uri.Sigil1, uri.MXID1) } // SecondaryIdentifier returns the second Matrix identifier in the URI. // Currently only event IDs can be in the secondary identifier slot. func (uri *MatrixURI) SecondaryIdentifier() string { - if uri == nil || uri.Sigil2 == 0 { + if uri.Sigil2 == 0 { return "" } return fmt.Sprintf("%c%s", uri.Sigil2, uri.MXID2) @@ -119,7 +110,7 @@ func (uri *MatrixURI) SecondaryIdentifier() string { // UserID returns the user ID from the URI if the primary identifier is a user ID. func (uri *MatrixURI) UserID() UserID { - if uri != nil && uri.Sigil1 == '@' { + if uri.Sigil1 == '@' { return UserID(uri.PrimaryIdentifier()) } return "" @@ -127,7 +118,7 @@ func (uri *MatrixURI) UserID() UserID { // RoomID returns the room ID from the URI if the primary identifier is a room ID. func (uri *MatrixURI) RoomID() RoomID { - if uri != nil && uri.Sigil1 == '!' { + if uri.Sigil1 == '!' { return RoomID(uri.PrimaryIdentifier()) } return "" @@ -135,7 +126,7 @@ func (uri *MatrixURI) RoomID() RoomID { // RoomAlias returns the room alias from the URI if the primary identifier is a room alias. func (uri *MatrixURI) RoomAlias() RoomAlias { - if uri != nil && uri.Sigil1 == '#' { + if uri.Sigil1 == '#' { return RoomAlias(uri.PrimaryIdentifier()) } return "" @@ -143,7 +134,7 @@ func (uri *MatrixURI) RoomAlias() RoomAlias { // EventID returns the event ID from the URI if the primary identifier is a room ID or alias and the secondary identifier is an event ID. func (uri *MatrixURI) EventID() EventID { - if uri != nil && (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' { + if (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' { return EventID(uri.SecondaryIdentifier()) } return "" @@ -210,14 +201,10 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[1]) == 0 { return nil, ErrEmptySecondSegment } - var err error - parsed.MXID1, err = url.PathUnescape(parts[1]) - if err != nil { - return nil, fmt.Errorf("failed to url decode second segment %q: %w", parts[1], err) - } + parsed.MXID1 = parts[1] // Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier - if parsed.Sigil1 == '!' && len(parts) == 4 { + if (parsed.Sigil1 == '!' || parsed.Sigil1 == '#') && len(parts) == 4 { // a: find the sigil from the third segment switch parts[2] { case "e", "event": @@ -230,10 +217,7 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[3]) == 0 { return nil, ErrEmptyFourthSegment } - parsed.MXID2, err = url.PathUnescape(parts[3]) - if err != nil { - return nil, fmt.Errorf("failed to url decode fourth segment %q: %w", parts[3], err) - } + parsed.MXID2 = parts[3] } // Step 7: parse the query and extract via and action items diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go index 90a0754d..d26d4bfd 100644 --- a/id/matrixuri_test.go +++ b/id/matrixuri_test.go @@ -16,11 +16,12 @@ import ( ) var ( - roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"} - roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}} - roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"} - roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} - userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"} + roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"} + roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}} + roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"} + roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} + roomAliasEventLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} + userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"} escapeRoomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "meow & 🐈️:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF/dtndJ0j9je+kIK3XpV1s"} ) @@ -30,6 +31,7 @@ func TestMatrixURI_MatrixToURL(t *testing.T) { assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%23someroom:example.org", roomAliasLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.MatrixToURL()) + assert.Equal(t, "https://matrix.to/#/%23someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/@user:example.org", userLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%21meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/$uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.MatrixToURL()) } @@ -39,6 +41,7 @@ func TestMatrixURI_String(t *testing.T) { assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.String()) assert.Equal(t, "matrix:r/someroom:example.org", roomAliasLink.String()) assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.String()) + assert.Equal(t, "matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.String()) assert.Equal(t, "matrix:u/user:example.org", userLink.String()) assert.Equal(t, "matrix:roomid/meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/e/uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.String()) } @@ -77,12 +80,8 @@ func TestParseMatrixURI_RoomID(t *testing.T) { parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org") require.NoError(t, err) require.NotNil(t, parsedVia) - parsedEncoded, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl%3Aexample.org") - require.NoError(t, err) - require.NotNil(t, parsedEncoded) assert.Equal(t, roomIDLink, *parsed) - assert.Equal(t, roomIDLink, *parsedEncoded) assert.Equal(t, roomIDViaLink, *parsedVia) } @@ -99,11 +98,19 @@ func TestParseMatrixURI_UserID(t *testing.T) { } func TestParseMatrixURI_EventID(t *testing.T) { - parsed, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + parsed1, err := id.ParseMatrixURI("matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed) + require.NotNil(t, parsed1) + parsed2, err := id.ParseMatrixURI("matrix:room/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + require.NoError(t, err) + require.NotNil(t, parsed2) + parsed3, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + require.NoError(t, err) + require.NotNil(t, parsed3) - assert.Equal(t, roomIDEventLink, *parsed) + assert.Equal(t, roomAliasEventLink, *parsed1) + assert.Equal(t, roomAliasEventLink, *parsed2) + assert.Equal(t, roomIDEventLink, *parsed3) } func TestParseMatrixToURL_RoomAlias(t *testing.T) { @@ -151,13 +158,21 @@ func TestParseMatrixToURL_UserID(t *testing.T) { } func TestParseMatrixToURL_EventID(t *testing.T) { - parsed, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + parsed1, err := id.ParseMatrixToURL("https://matrix.to/#/#someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed) - parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + require.NotNil(t, parsed1) + parsed2, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsedEncoded) + require.NotNil(t, parsed2) + parsed1Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%23someroom:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + require.NoError(t, err) + require.NotNil(t, parsed1) + parsed2Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + require.NoError(t, err) + require.NotNil(t, parsed2) - assert.Equal(t, roomIDEventLink, *parsed) - assert.Equal(t, roomIDEventLink, *parsedEncoded) + assert.Equal(t, roomAliasEventLink, *parsed1) + assert.Equal(t, roomAliasEventLink, *parsed1Encoded) + assert.Equal(t, roomIDEventLink, *parsed2) + assert.Equal(t, roomIDEventLink, *parsed2Encoded) } diff --git a/id/opaque.go b/id/opaque.go index c1ad4988..16863b95 100644 --- a/id/opaque.go +++ b/id/opaque.go @@ -32,17 +32,11 @@ type EventID string // https://github.com/matrix-org/matrix-doc/pull/2716 type BatchID string -// A DelayID is a string identifying a delayed event. -type DelayID string - func (roomID RoomID) String() string { return string(roomID) } func (roomID RoomID) URI(via ...string) *MatrixURI { - if roomID == "" { - return nil - } return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], @@ -51,11 +45,6 @@ func (roomID RoomID) URI(via ...string) *MatrixURI { } func (roomID RoomID) EventURI(eventID EventID, via ...string) *MatrixURI { - if roomID == "" { - return nil - } else if eventID == "" { - return roomID.URI(via...) - } return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], @@ -70,20 +59,13 @@ func (roomAlias RoomAlias) String() string { } func (roomAlias RoomAlias) URI() *MatrixURI { - if roomAlias == "" { - return nil - } return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], } } -// Deprecated: room alias event links should not be used. Use room IDs instead. func (roomAlias RoomAlias) EventURI(eventID EventID) *MatrixURI { - if roomAlias == "" { - return nil - } return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], diff --git a/id/roomversion.go b/id/roomversion.go deleted file mode 100644 index 578c10bd..00000000 --- a/id/roomversion.go +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package id - -import ( - "errors" - "fmt" - "slices" -) - -type RoomVersion string - -const ( - RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1 - RoomV1 RoomVersion = "1" - RoomV2 RoomVersion = "2" - RoomV3 RoomVersion = "3" - RoomV4 RoomVersion = "4" - RoomV5 RoomVersion = "5" - RoomV6 RoomVersion = "6" - RoomV7 RoomVersion = "7" - RoomV8 RoomVersion = "8" - RoomV9 RoomVersion = "9" - RoomV10 RoomVersion = "10" - RoomV11 RoomVersion = "11" - RoomV12 RoomVersion = "12" -) - -func (rv RoomVersion) Equals(versions ...RoomVersion) bool { - return slices.Contains(versions, rv) -} - -func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool { - return !rv.Equals(versions...) -} - -var ErrUnknownRoomVersion = errors.New("unknown room version") - -func (rv RoomVersion) unknownVersionError() error { - return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv) -} - -func (rv RoomVersion) IsKnown() bool { - switch rv { - case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12: - return true - default: - return false - } -} - -type StateResVersion int - -const ( - // StateResV1 is the original state resolution algorithm. - StateResV1 StateResVersion = 0 - // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759 - StateResV2 StateResVersion = 1 - // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297 - StateResV2_1 StateResVersion = 2 -) - -// StateResVersion returns the version of the state resolution algorithm used by this room version. -func (rv RoomVersion) StateResVersion() StateResVersion { - switch rv { - case RoomV0, RoomV1: - return StateResV1 - case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11: - return StateResV2 - case RoomV12: - return StateResV2_1 - default: - panic(rv.unknownVersionError()) - } -} - -type EventIDFormat int - -const ( - // EventIDFormatCustom is the original format used by room v1 and v2. - // Event IDs in this format are an arbitrary string followed by a colon and the server name. - EventIDFormatCustom EventIDFormat = 0 - // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659. - // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event. - EventIDFormatBase64 EventIDFormat = 1 - // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002. - // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event. - EventIDFormatURLSafeBase64 EventIDFormat = 2 -) - -// EventIDFormat returns the format of event IDs used by this room version. -func (rv RoomVersion) EventIDFormat() EventIDFormat { - switch rv { - case RoomV0, RoomV1, RoomV2: - return EventIDFormatCustom - case RoomV3: - return EventIDFormatBase64 - default: - return EventIDFormatURLSafeBase64 - } -} - -///////////////////// -// Room v5 changes // -///////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/2077 - -// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys -// must be enforced on received events. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076 -func (rv RoomVersion) EnforceSigningKeyValidity() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4) -} - -///////////////////// -// Room v6 changes // -///////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/2240 - -// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased -// to only always allow servers to modify the state event with their own server name as state key. -// This also implies that the `aliases` field is protected from redactions. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432 -func (rv RoomVersion) SpecialCasedAliasesAuth() bool { - return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) -} - -// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540 -func (rv RoomVersion) ForbidFloatsAndBigInts() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) -} - -// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth. -// However, the field is not protected from redactions. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209 -func (rv RoomVersion) NotificationsPowerLevels() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) -} - -///////////////////// -// Room v7 changes // -///////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/2998 - -// Knocks returns true if the `knock` join rule is supported. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403 -func (rv RoomVersion) Knocks() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6) -} - -///////////////////// -// Room v8 changes // -///////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/3289 - -// RestrictedJoins returns true if the `restricted` join rule is supported. -// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083 -func (rv RoomVersion) RestrictedJoins() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7) -} - -///////////////////// -// Room v9 changes // -///////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/3375 - -// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375 -func (rv RoomVersion) RestrictedJoinsFix() bool { - return rv.RestrictedJoins() && rv != RoomV8 -} - -////////////////////// -// Room v10 changes // -////////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/3604 - -// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings). -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667 -func (rv RoomVersion) ValidatePowerLevelInts() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) -} - -// KnockRestricted returns true if the `knock_restricted` join rule is supported. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787 -func (rv RoomVersion) KnockRestricted() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) -} - -////////////////////// -// Room v11 changes // -////////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/3820 - -// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175 -func (rv RoomVersion) CreatorInContent() bool { - return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) -} - -// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level. -// The redaction protection is also moved from the top level to the content field. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174 -// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection). -func (rv RoomVersion) RedactsInContent() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) -} - -// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied. -// -// Specifically: -// -// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected. -// * the entire content of `m.room.create` is protected. -// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field. -// * the `m.room.power_levels` event protects the `invite` field in content. -// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176, -// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and -// https://github.com/matrix-org/matrix-spec-proposals/pull/3989 -func (rv RoomVersion) UpdatedRedactionRules() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) -} - -////////////////////// -// Room v12 changes // -////////////////////// -// https://github.com/matrix-org/matrix-spec-proposals/pull/4304 - -// Return value of StateResVersion was changed to StateResV2_1 - -// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level. -// This also implies that the `m.room.create` event has an `additional_creators` field, -// and that the creators can't be present in the `m.room.power_levels` event. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289 -func (rv RoomVersion) PrivilegedRoomCreators() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) -} - -// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event. -// This also implies that `m.room.create` events do not have a `room_id` field. -// -// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291 -func (rv RoomVersion) RoomIDIsCreateEventID() bool { - return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) -} diff --git a/id/servername.go b/id/servername.go deleted file mode 100644 index 923705b6..00000000 --- a/id/servername.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package id - -import ( - "regexp" - "strconv" -) - -type ParsedServerNameType int - -const ( - ServerNameDNS ParsedServerNameType = iota - ServerNameIPv4 - ServerNameIPv6 -) - -type ParsedServerName struct { - Type ParsedServerNameType - Host string - Port int -} - -var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(?::(\d{1,5}))?$`) - -func ValidateServerName(serverName string) bool { - return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName) -} - -func ParseServerName(serverName string) *ParsedServerName { - if len(serverName) > 255 || len(serverName) < 1 { - return nil - } - match := ServerNameRegex.FindStringSubmatch(serverName) - if len(match) != 5 { - return nil - } - port, _ := strconv.Atoi(match[4]) - parsed := &ParsedServerName{ - Port: port, - } - switch { - case match[1] != "": - parsed.Type = ServerNameIPv6 - parsed.Host = match[1] - case match[2] != "": - parsed.Type = ServerNameIPv4 - parsed.Host = match[2] - case match[3] != "": - parsed.Type = ServerNameDNS - parsed.Host = match[3] - } - return parsed -} diff --git a/id/trust.go b/id/trust.go index 6255093e..04f6e36b 100644 --- a/id/trust.go +++ b/id/trust.go @@ -16,7 +16,6 @@ type TrustState int const ( TrustStateBlacklisted TrustState = -100 - TrustStateDeviceKeyMismatch TrustState = -5 TrustStateUnset TrustState = 0 TrustStateUnknownDevice TrustState = 10 TrustStateForwarded TrustState = 20 @@ -24,7 +23,7 @@ const ( TrustStateCrossSignedTOFU TrustState = 100 TrustStateCrossSignedVerified TrustState = 200 TrustStateVerified TrustState = 300 - TrustStateInvalid TrustState = -2147483647 + TrustStateInvalid TrustState = (1 << 31) - 1 ) func (ts *TrustState) UnmarshalText(data []byte) error { @@ -45,8 +44,6 @@ func ParseTrustState(val string) TrustState { switch strings.ToLower(val) { case "blacklisted": return TrustStateBlacklisted - case "device-key-mismatch": - return TrustStateDeviceKeyMismatch case "unverified": return TrustStateUnset case "cross-signed-untrusted": @@ -70,8 +67,6 @@ func (ts TrustState) String() string { switch ts { case TrustStateBlacklisted: return "blacklisted" - case TrustStateDeviceKeyMismatch: - return "device-key-mismatch" case TrustStateUnset: return "unverified" case TrustStateCrossSignedUntrusted: diff --git a/id/userid.go b/id/userid.go index 726a0d58..53b68b96 100644 --- a/id/userid.go +++ b/id/userid.go @@ -30,11 +30,10 @@ func NewEncodedUserID(localpart, homeserver string) UserID { } var ( - ErrInvalidUserID = errors.New("is not a valid user ID") - ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") - ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") - ErrEmptyLocalpart = errors.New("empty localparts are not allowed") - ErrNoncompliantServerPart = errors.New("is not a valid server name") + ErrInvalidUserID = errors.New("is not a valid user ID") + ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") + ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") + ErrEmptyLocalpart = errors.New("empty localparts are not allowed") ) // ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format @@ -44,10 +43,10 @@ func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, } sigil = identifier[0] strIdentifier := string(identifier) - colonIdx := strings.IndexByte(strIdentifier, ':') - if colonIdx > 0 { - localpart = strIdentifier[1:colonIdx] - homeserver = strIdentifier[colonIdx+1:] + if strings.ContainsRune(strIdentifier, ':') { + parts := strings.SplitN(strIdentifier, ":", 2) + localpart = parts[0][1:] + homeserver = parts[1] } else { localpart = strIdentifier[1:] } @@ -82,9 +81,6 @@ func (userID UserID) Homeserver() string { // // This does not parse or validate the user ID. Use the ParseAndValidate method if you want to ensure the user ID is valid first. func (userID UserID) URI() *MatrixURI { - if userID == "" { - return nil - } return &MatrixURI{ Sigil1: '@', MXID1: string(userID)[1:], @@ -104,32 +100,21 @@ func ValidateUserLocalpart(localpart string) error { return nil } -// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts. -// This should be used with care: there are real users still using historical localparts. -func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.ParseAndValidateRelaxed() +// ParseAndValidate parses the user ID into the localpart and server name like Parse, +// and also validates that the localpart is allowed according to the user identifiers spec. +func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) { + localpart, homeserver, err = userID.Parse() if err == nil { err = ValidateUserLocalpart(localpart) } - return -} - -// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse, -// and also validates that the user ID is not too long and that the server name is valid. -func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) { - if len(userID) > UserIDMaxLength { + if err == nil && len(userID) > UserIDMaxLength { err = ErrUserIDTooLong - return - } - localpart, homeserver, err = userID.Parse() - if err == nil && !ValidateServerName(homeserver) { - err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart) } return } func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.ParseAndValidateStrict() + localpart, homeserver, err = userID.ParseAndValidate() if err == nil { localpart, err = DecodeUserLocalpart(localpart) } @@ -219,15 +204,15 @@ func DecodeUserLocalpart(str string) (string, error) { for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { - return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b) + return "", fmt.Errorf("Byte pos %d: Invalid byte", i) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { - return "", fmt.Errorf("unexpected end of string after underscore at %d", i) + return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping - return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i) + return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') @@ -237,7 +222,7 @@ func DecodeUserLocalpart(str string) (string, error) { i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { - return "", fmt.Errorf("unexpected end of string after equals sign at %d", i) + return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) diff --git a/id/userid_test.go b/id/userid_test.go index 57a88066..359bc687 100644 --- a/id/userid_test.go +++ b/id/userid_test.go @@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) { assert.True(t, errors.Is(err, id.ErrInvalidUserID)) } -func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) { +func TestUserID_ParseAndValidate_Invalid(t *testing.T) { const inputUserID = "@s p a c e:maunium.net" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart)) } -func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) { +func TestUserID_ParseAndValidate_Empty(t *testing.T) { const inputUserID = "@:ponies.im" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrEmptyLocalpart)) } -func TestUserID_ParseAndValidateStrict_Long(t *testing.T) { +func TestUserID_ParseAndValidate_Long(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrUserIDTooLong)) } -func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) { +func TestUserID_ParseAndValidate_NotLong(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.NoError(t, err) } @@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) { const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" const inputServerName = "example.com" userID := id.NewEncodedUserID(inputLocalpart, inputServerName) - parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict() + parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() assert.NoError(t, err) assert.Equal(t, encodedLocalpart, parsedLocalpart) assert.Equal(t, inputServerName, parsedServerName) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 4d2bc7cf..f2591428 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,107 +8,51 @@ package mediaproxy import ( "context" + "encoding/json" "errors" "fmt" "io" "mime" "mime/multipart" + "net" "net/http" "net/textproto" - "net/url" - "os" "strconv" "strings" "time" + "github.com/gorilla/mux" "github.com/rs/zerolog" - "github.com/rs/zerolog/hlog" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exhttp" - "go.mau.fi/util/ptr" - "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/federation" - "maunium.net/go/mautrix/id" ) type GetMediaResponse interface { isGetMediaResponse() } -func (*GetMediaResponseURL) isGetMediaResponse() {} -func (*GetMediaResponseData) isGetMediaResponse() {} -func (*GetMediaResponseCallback) isGetMediaResponse() {} -func (*GetMediaResponseFile) isGetMediaResponse() {} +func (*GetMediaResponseURL) isGetMediaResponse() {} +func (*GetMediaResponseData) isGetMediaResponse() {} type GetMediaResponseURL struct { URL string ExpiresAt time.Time } -type GetMediaResponseWriter interface { - GetMediaResponse - io.WriterTo - GetContentType() string - GetContentLength() int64 -} - -var ( - _ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil) - _ GetMediaResponseWriter = (*GetMediaResponseData)(nil) -) - type GetMediaResponseData struct { Reader io.ReadCloser ContentType string ContentLength int64 } -func (d *GetMediaResponseData) WriteTo(w io.Writer) (int64, error) { - return io.Copy(w, d.Reader) -} - -func (d *GetMediaResponseData) GetContentType() string { - return d.ContentType -} - -func (d *GetMediaResponseData) GetContentLength() int64 { - return d.ContentLength -} - -type GetMediaResponseCallback struct { - Callback func(w io.Writer) (int64, error) - ContentType string - ContentLength int64 -} - -func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) { - return d.Callback(w) -} - -func (d *GetMediaResponseCallback) GetContentLength() int64 { - return d.ContentLength -} - -func (d *GetMediaResponseCallback) GetContentType() string { - return d.ContentType -} - -type FileMeta struct { - ContentType string - ReplacementFile string -} - -type GetMediaResponseFile struct { - Callback func(w *os.File) (*FileMeta, error) -} - -type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) +type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) type MediaProxy struct { - KeyServer *federation.KeyServer - ServerAuth *federation.ServerAuth + KeyServer *federation.KeyServer + ProxyClient *http.Client + + ForceProxyLegacyFederation bool GetMedia GetMediaFunc PrepareProxyRequest func(*http.Request) @@ -116,8 +60,9 @@ type MediaProxy struct { serverName string serverKey *federation.SigningKey - FederationRouter *http.ServeMux - ClientMediaRouter *http.ServeMux + FederationRouter *mux.Router + LegacyMediaRouter *mux.Router + ClientMediaRouter *mux.Router } func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) { @@ -125,10 +70,18 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx if err != nil { return nil, err } - mp := &MediaProxy{ + return &MediaProxy{ serverName: serverName, serverKey: parsed, GetMedia: getMedia, + ProxyClient: &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: false, + }, + Timeout: 60 * time.Second, + }, KeyServer: &federation.KeyServer{ KeyProvider: &federation.StaticServerKey{ ServerName: serverName, @@ -140,27 +93,13 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"), }, }, - } - mp.FederationRouter = http.NewServeMux() - mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) - mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation) - mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) - mp.ClientMediaRouter = http.NewServeMux() - mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) - mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia) - mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia) - mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported) - mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported) - mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported) - mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported) - mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported) - return mp, nil + }, nil } type BasicConfig struct { ServerName string `yaml:"server_name" json:"server_name"` ServerKey string `yaml:"server_key" json:"server_key"` - FederationAuth bool `yaml:"federation_auth" json:"federation_auth"` + AllowProxy bool `yaml:"allow_proxy" json:"allow_proxy"` WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"` } @@ -169,12 +108,12 @@ func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) if err != nil { return nil, err } + if !cfg.AllowProxy { + mp.DisallowProxying() + } if cfg.WellKnownResponse != "" { mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse } - if cfg.FederationAuth { - mp.EnableServerAuth(nil, nil) - } return mp, nil } @@ -184,8 +123,8 @@ type ServerConfig struct { } func (mp *MediaProxy) Listen(cfg ServerConfig) error { - router := http.NewServeMux() - mp.RegisterRoutes(router, zerolog.Nop()) + router := mux.NewRouter() + mp.RegisterRoutes(router) return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) } @@ -197,183 +136,99 @@ func (mp *MediaProxy) GetServerKey() *federation.SigningKey { return mp.serverKey } -func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache federation.KeyCache) { - if keyCache == nil { - keyCache = federation.NewInMemoryCache() - } - if client == nil { - resCache, _ := keyCache.(federation.ResolutionCache) - client = federation.NewClient(mp.serverName, mp.serverKey, resCache) - } - mp.ServerAuth = federation.NewServerAuth(client, keyCache, func(auth federation.XMatrixAuth) string { - return mp.GetServerName() - }) +func (mp *MediaProxy) DisallowProxying() { + mp.ProxyClient = nil } -func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) { - errorBodies := exhttp.ErrorBodies{ - NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), - MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), +func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { + if mp.FederationRouter == nil { + mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() } - router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware( - mp.FederationRouter, - exhttp.StripPrefix("/_matrix/federation"), - hlog.NewHandler(log), - hlog.RequestIDHandler("request_id", "Request-Id"), - requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), - exhttp.HandleErrors(errorBodies), - )) - router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware( - mp.ClientMediaRouter, - exhttp.StripPrefix("/_matrix/client/v1/media"), - hlog.NewHandler(log), - hlog.RequestIDHandler("request_id", "Request-Id"), - exhttp.CORSMiddleware, - requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), - exhttp.HandleErrors(errorBodies), - )) - mp.KeyServer.Register(router, log) + if mp.LegacyMediaRouter == nil { + mp.LegacyMediaRouter = router.PathPrefix("/_matrix/media").Subrouter() + } + if mp.ClientMediaRouter == nil { + mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter() + } + + mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet) + mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet) + addClientRoutes := func(router *mux.Router, prefix string) { + router.HandleFunc(prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) + router.HandleFunc(prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) + router.HandleFunc(prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) + router.HandleFunc(prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) + router.HandleFunc(prefix+"/upload", mp.UploadNotSupported).Methods(http.MethodPost) + router.HandleFunc(prefix+"/create", mp.UploadNotSupported).Methods(http.MethodPost) + router.HandleFunc(prefix+"/config", mp.UploadNotSupported).Methods(http.MethodGet) + router.HandleFunc(prefix+"/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) + } + addClientRoutes(mp.LegacyMediaRouter, "/v3") + addClientRoutes(mp.LegacyMediaRouter, "/r0") + addClientRoutes(mp.LegacyMediaRouter, "/v1") + addClientRoutes(mp.ClientMediaRouter, "") + mp.LegacyMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) + mp.LegacyMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) + mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) + mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) + mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) + mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) + corsMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") + w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';") + next.ServeHTTP(w, r) + }) + } + mp.LegacyMediaRouter.Use(corsMiddleware) + mp.ClientMediaRouter.Use(corsMiddleware) + mp.KeyServer.Register(router) } -var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") - -func queryToMap(vals url.Values) map[string]string { - m := make(map[string]string, len(vals)) - for k, v := range vals { - m[k] = v[0] - } - return m -} - -func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { - mediaID := r.PathValue("mediaID") - if !id.IsValidMediaID(mediaID) { - mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w) - return nil - } - resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query())) - if err != nil { - var mautrixRespError mautrix.RespError - if errors.Is(err, ErrInvalidMediaIDSyntax) { - mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) - } else if errors.As(err, &mautrixRespError) { - mautrixRespError.Write(w) - } else { - zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL") - mautrix.MNotFound.WithMessage("Media not found").Write(w) - } - return nil - } - return resp -} - -func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Writer { - mpw := multipart.NewWriter(w) - w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1)) - w.WriteHeader(http.StatusOK) - metaPart, err := mpw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {"application/json"}, - }) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create multipart metadata field") - return nil - } - _, err = metaPart.Write([]byte(`{}`)) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to write multipart metadata field") - return nil - } - return mpw -} - -func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) { - if mp.ServerAuth != nil { - var err *mautrix.RespError - r, err = mp.ServerAuth.Authenticate(r) - if err != nil { - err.Write(w) - return - } - } - ctx := r.Context() +func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, url, fileName string) { log := zerolog.Ctx(ctx) - - resp := mp.getMedia(w, r) - if resp == nil { - return - } - - var mpw *multipart.Writer - if urlResp, ok := resp.(*GetMediaResponseURL); ok { - mpw = startMultipart(ctx, w) - if mpw == nil { - return - } - _, err := mpw.CreatePart(textproto.MIMEHeader{ - "Location": {urlResp.URL}, - }) - if err != nil { - log.Err(err).Msg("Failed to create multipart redirect field") - return - } - } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { - responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { - mpw = startMultipart(ctx, w) - if mpw == nil { - return fmt.Errorf("failed to start multipart writer") - } - dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {mimeType}, - }) - if err != nil { - return fmt.Errorf("failed to create multipart data field: %w", err) - } - _, err = wt.WriteTo(dataPart) - return err - }) - if err != nil { - log.Err(err).Msg("Failed to do media proxy with temp file") - if !responseStarted { - var mautrixRespError mautrix.RespError - if errors.As(err, &mautrixRespError) { - mautrixRespError.Write(w) - } else { - mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) - } - } - return - } - } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { - mpw = startMultipart(ctx, w) - if mpw == nil { - return - } - dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {dataResp.GetContentType()}, - }) - if err != nil { - log.Err(err).Msg("Failed to create multipart data field") - return - } - _, err = dataResp.WriteTo(dataPart) - if err != nil { - log.Err(err).Msg("Failed to write multipart data field") - return - } - } else { - panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) - } - err := mpw.Close() + req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - log.Err(err).Msg("Failed to close multipart writer") + log.Err(err).Str("url", url).Msg("Failed to create proxy request") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Failed to create proxy request", + }) return } -} - -func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName string) { - w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (media proxy)") + if mp.PrepareProxyRequest != nil { + mp.PrepareProxyRequest(req) + } + resp, err := mp.ProxyClient.Do(req) + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if err != nil { + log.Err(err).Str("url", url).Msg("Failed to proxy download") + jsonResponse(w, http.StatusServiceUnavailable, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Failed to proxy download", + }) + return + } else if resp.StatusCode != http.StatusOK { + log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download") + jsonResponse(w, resp.StatusCode, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Unexpected status code proxying download", + }) + return + } + w.Header()["Content-Type"] = resp.Header["Content-Type"] + w.Header()["Content-Length"] = resp.Header["Content-Length"] + w.Header()["Last-Modified"] = resp.Header["Last-Modified"] + w.Header()["Cache-Control"] = resp.Header["Cache-Control"] contentDisposition := "attachment" - switch mimeType { + switch resp.Header.Get("Content-Type") { case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime", "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav", @@ -386,14 +241,113 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin }) } w.Header().Set("Content-Disposition", contentDisposition) - w.Header().Set("Content-Type", mimeType) + w.WriteHeader(http.StatusOK) + _, err = io.Copy(w, resp.Body) + if err != nil { + log.Debug().Err(err).Msg("Failed to write proxy response") + } +} + +type ResponseError struct { + Status int + Data any +} + +func (err *ResponseError) Error() string { + return fmt.Sprintf("HTTP %d: %v", err.Status, err.Data) +} + +var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") + +func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { + mediaID := mux.Vars(r)["mediaID"] + resp, err := mp.GetMedia(r.Context(), mediaID) + if err != nil { + var respError *ResponseError + if errors.Is(err, ErrInvalidMediaIDSyntax) { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName), + }) + } else if errors.As(err, &respError) { + jsonResponse(w, respError.Status, respError.Data) + } else { + zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL") + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "Media not found", + }) + } + return nil + } + return resp +} + +func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := zerolog.Ctx(ctx) + // TODO check destination header in X-Matrix auth + + resp := mp.getMedia(w, r) + if resp == nil { + return + } + + mpw := multipart.NewWriter(w) + w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1)) + w.WriteHeader(http.StatusOK) + metaPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart metadata field") + return + } + _, err = metaPart.Write([]byte(`{}`)) + if err != nil { + log.Err(err).Msg("Failed to write multipart metadata field") + return + } + if urlResp, ok := resp.(*GetMediaResponseURL); ok { + _, err = mpw.CreatePart(textproto.MIMEHeader{ + "Location": {urlResp.URL}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart redirect field") + return + } + } else if dataResp, ok := resp.(*GetMediaResponseData); ok { + dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {dataResp.ContentType}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart data field") + return + } + _, err = io.Copy(dataPart, dataResp.Reader) + if err != nil { + log.Err(err).Msg("Failed to write multipart data field") + return + } + } else { + panic("unknown GetMediaResponse type") + } + err = mpw.Close() + if err != nil { + log.Err(err).Msg("Failed to close multipart writer") + return + } } func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) - if r.PathValue("serverName") != mp.serverName { - mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) + vars := mux.Vars(r) + if vars["serverName"] != mp.serverName { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName), + }) return } resp := mp.getMedia(w, r) @@ -402,6 +356,13 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } if urlResp, ok := resp.(*GetMediaResponseURL); ok { + // Proxy if the config allows proxying and the request doesn't allow redirects. + // In any other case, redirect to the URL. + isFederated := strings.HasPrefix(r.Header.Get("Authorization"), "X-Matrix") + if mp.ProxyClient != nil && (r.URL.Query().Get("allow_redirect") != "true" || (mp.ForceProxyLegacyFederation && isFederated)) { + mp.proxyDownload(ctx, w, urlResp.URL, vars["fileName"]) + return + } w.Header().Set("Location", urlResp.URL) expirySeconds := (time.Until(urlResp.ExpiresAt) - 5*time.Minute).Seconds() if urlResp.ExpiresAt.IsZero() { @@ -413,113 +374,51 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store") } w.WriteHeader(http.StatusTemporaryRedirect) - } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { - responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { - mp.addHeaders(w, mimeType, r.PathValue("fileName")) - w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) - w.WriteHeader(http.StatusOK) - _, err := wt.WriteTo(w) - return err - }) - if err != nil { - log.Err(err).Msg("Failed to do media proxy with temp file") - if !responseStarted { - var mautrixRespError mautrix.RespError - if errors.As(err, &mautrixRespError) { - mautrixRespError.Write(w) - } else { - mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) - } - } - } - } else if writerResp, ok := resp.(GetMediaResponseWriter); ok { - if dataResp, ok := writerResp.(*GetMediaResponseData); ok { - defer dataResp.Reader.Close() - } - mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName")) - if writerResp.GetContentLength() != 0 { - w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10)) + } else if dataResp, ok := resp.(*GetMediaResponseData); ok { + w.Header().Set("Content-Type", dataResp.ContentType) + if dataResp.ContentLength != 0 { + w.Header().Set("Content-Length", strconv.FormatInt(dataResp.ContentLength, 10)) } w.WriteHeader(http.StatusOK) - _, err := writerResp.WriteTo(w) + _, err := io.Copy(w, dataResp.Reader) if err != nil { log.Err(err).Msg("Failed to write media data") } } else { - panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) + panic("unknown GetMediaResponse type") } } -func doTempFileDownload( - data *GetMediaResponseFile, - respond func(w io.WriterTo, size int64, mimeType string) error, -) (bool, error) { - tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*") - if err != nil { - return false, fmt.Errorf("failed to create temp file: %w", err) - } - origTempFile := tempFile - defer func() { - _ = origTempFile.Close() - _ = os.Remove(origTempFile.Name()) - }() - meta, err := data.Callback(tempFile) - if err != nil { - return false, err - } - if meta.ReplacementFile != "" { - tempFile, err = os.Open(meta.ReplacementFile) - if err != nil { - return false, fmt.Errorf("failed to open replacement file: %w", err) - } - defer func() { - _ = tempFile.Close() - _ = os.Remove(origTempFile.Name()) - }() - } else { - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return false, fmt.Errorf("failed to seek to start of temp file: %w", err) - } - } - fileInfo, err := tempFile.Stat() - if err != nil { - return false, fmt.Errorf("failed to stat temp file: %w", err) - } - mimeType := meta.ContentType - if mimeType == "" { - buf := make([]byte, 512) - n, err := tempFile.Read(buf) - if err != nil { - return false, fmt.Errorf("failed to read temp file to detect mime: %w", err) - } - buf = buf[:n] - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return false, fmt.Errorf("failed to seek to start of temp file: %w", err) - } - mimeType = http.DetectContentType(buf) - } - err = respond(tempFile, fileInfo.Size(), mimeType) - if err != nil { - return true, err - } - return true, nil +func jsonResponse(w http.ResponseWriter, status int, response interface{}) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(response) } -var ( - ErrUploadNotSupported = mautrix.MUnrecognized. - WithMessage("This is a media proxy and does not support media uploads."). - WithStatus(http.StatusNotImplemented) - ErrPreviewURLNotSupported = mautrix.MUnrecognized. - WithMessage("This is a media proxy and does not support URL previews."). - WithStatus(http.StatusNotImplemented) -) - func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) { - ErrUploadNotSupported.Write(w) + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "This is a media proxy and does not support media uploads.", + }) } func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { - ErrPreviewURLNotSupported.Write(w) + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "This is a media proxy and does not support URL previews.", + }) +} + +func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Unrecognized endpoint", + }) +} + +func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Invalid method for endpoint", + }) } diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go deleted file mode 100644 index 507c24a5..00000000 --- a/mockserver/mockserver.go +++ /dev/null @@ -1,307 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mockserver - -import ( - "context" - "encoding/json" - "fmt" - "io" - "maps" - "net/http" - "net/http/httptest" - "strings" - "testing" - - globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log - "github.com/stretchr/testify/require" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exhttp" - "go.mau.fi/util/random" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/cryptohelper" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func mustDecode(r *http.Request, data any) { - exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data)) -} - -type userAndDeviceID struct { - UserID id.UserID - DeviceID id.DeviceID -} - -type MockServer struct { - Router *http.ServeMux - Server *httptest.Server - - AccessTokenToUserID map[string]userAndDeviceID - DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event - AccountData map[id.UserID]map[event.Type]json.RawMessage - DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys - OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey - MasterKeys map[id.UserID]mautrix.CrossSigningKeys - SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys - UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys - - PopOTKs bool - MemoryStore bool -} - -func Create(t testing.TB) *MockServer { - t.Helper() - - server := MockServer{ - AccessTokenToUserID: map[string]userAndDeviceID{}, - DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, - AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, - DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, - OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, - MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - PopOTKs: true, - MemoryStore: true, - } - - router := http.NewServeMux() - router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) - router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) - router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim) - router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) - router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) - router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) - router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) - router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) - server.Router = router - server.Server = httptest.NewServer(router) - t.Cleanup(server.Server.Close) - return &server -} - -func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID { - authHeader := r.Header.Get("Authorization") - authHeader = strings.TrimPrefix(authHeader, "Bearer ") - userID, ok := ms.AccessTokenToUserID[authHeader] - if !ok { - panic("no user ID found for access token " + authHeader) - } - return userID -} - -func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) -} - -func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { - var loginReq mautrix.ReqLogin - mustDecode(r, &loginReq) - - deviceID := loginReq.DeviceID - if deviceID == "" { - deviceID = id.DeviceID(random.String(10)) - } - - accessToken := random.String(30) - userID := id.UserID(loginReq.Identifier.User) - ms.AccessTokenToUserID[accessToken] = userAndDeviceID{ - UserID: userID, - DeviceID: deviceID, - } - - exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{ - AccessToken: accessToken, - DeviceID: deviceID, - UserID: userID, - }) -} - -func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqSendToDevice - mustDecode(r, &req) - evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} - - for user, devices := range req.Messages { - for device, content := range devices { - if _, ok := ms.DeviceInbox[user]; !ok { - ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{} - } - content.ParseRaw(evtType) - ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{ - Sender: ms.getUserID(r).UserID, - Type: evtType, - Content: *content, - }) - } - } - ms.emptyResp(w, r) -} - -func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) { - userID := id.UserID(r.PathValue("userID")) - eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} - - jsonData, _ := io.ReadAll(r.Body) - if _, ok := ms.AccountData[userID]; !ok { - ms.AccountData[userID] = map[event.Type]json.RawMessage{} - } - ms.AccountData[userID][eventType] = json.RawMessage(jsonData) - ms.emptyResp(w, r) -} - -func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqQueryKeys - mustDecode(r, &req) - resp := mautrix.RespQueryKeys{ - MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, - } - for user := range req.DeviceKeys { - resp.MasterKeys[user] = ms.MasterKeys[user] - resp.UserSigningKeys[user] = ms.UserSigningKeys[user] - resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user] - resp.DeviceKeys[user] = ms.DeviceKeys[user] - } - exhttp.WriteJSONResponse(w, http.StatusOK, &resp) -} - -func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqClaimKeys - mustDecode(r, &req) - resp := mautrix.RespClaimKeys{ - OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, - } - for user, devices := range req.OneTimeKeys { - resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} - for device := range devices { - keys := ms.OneTimeKeys[user][device] - for keyID, key := range keys { - if ms.PopOTKs { - delete(keys, keyID) - } - resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{ - keyID: key, - } - break - } - } - } - exhttp.WriteJSONResponse(w, http.StatusOK, &resp) -} - -func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqUploadKeys - mustDecode(r, &req) - - uid := ms.getUserID(r) - userID := uid.UserID - if _, ok := ms.DeviceKeys[userID]; !ok { - ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} - } - if _, ok := ms.OneTimeKeys[userID]; !ok { - ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} - } - - if req.DeviceKeys != nil { - ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys - } - otks, ok := ms.OneTimeKeys[userID][uid.DeviceID] - if !ok { - otks = map[id.KeyID]mautrix.OneTimeKey{} - ms.OneTimeKeys[userID][uid.DeviceID] = otks - } - if req.OneTimeKeys != nil { - maps.Copy(otks, req.OneTimeKeys) - } - - exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{ - OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)}, - }) -} - -func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.UploadCrossSigningKeysReq[any] - mustDecode(r, &req) - - userID := ms.getUserID(r).UserID - ms.MasterKeys[userID] = req.Master - ms.SelfSigningKeys[userID] = req.SelfSigning - ms.UserSigningKeys[userID] = req.UserSigning - - ms.emptyResp(w, r) -} - -func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { - t.Helper() - if ctx == nil { - ctx = context.TODO() - } - client, err := mautrix.NewClient(ms.Server.URL, "", "") - require.NoError(t, err) - client.Client = ms.Server.Client() - - _, err = client.Login(ctx, &mautrix.ReqLogin{ - Type: mautrix.AuthTypePassword, - Identifier: mautrix.UserIdentifier{ - Type: mautrix.IdentifierTypeUser, - User: userID.String(), - }, - DeviceID: deviceID, - Password: "password", - StoreCredentials: true, - }) - require.NoError(t, err) - - var store any - if ms.MemoryStore { - store = crypto.NewMemoryStore(nil) - client.StateStore = mautrix.NewMemoryStateStore() - } else { - store, err = dbutil.NewFromConfig("", dbutil.Config{ - PoolConfig: dbutil.PoolConfig{ - Type: "sqlite3-fk-wal", - URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)), - MaxOpenConns: 5, - MaxIdleConns: 1, - }, - }, nil) - require.NoError(t, err) - } - cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store) - require.NoError(t, err) - client.Crypto = cryptoHelper - - err = cryptoHelper.Init(ctx) - require.NoError(t, err) - - machineLog := globallog.Logger.With(). - Stringer("my_user_id", userID). - Stringer("my_device_id", deviceID). - Logger() - cryptoHelper.Machine().Log = &machineLog - - err = cryptoHelper.Machine().ShareKeys(ctx, 50) - require.NoError(t, err) - - return client, cryptoHelper.Machine().CryptoStore -} - -func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) { - t.Helper() - - for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { - client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) - ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] - } -} diff --git a/pushrules/action.go b/pushrules/action.go index b5a884b2..9838e88b 100644 --- a/pushrules/action.go +++ b/pushrules/action.go @@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error { if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) - action.Value = val["value"] + action.Value, _ = val["value"] } } return nil diff --git a/pushrules/action_test.go b/pushrules/action_test.go index 3c0aa168..a8f68415 100644 --- a/pushrules/action_test.go +++ b/pushrules/action_test.go @@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`)) - assert.NoError(t, err) + assert.Nil(t, err) err = pa.UnmarshalJSON([]byte(`9001`)) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`"foo"`)) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, pushrules.PushActionType("foo"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`)) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, pushrules.ActionSetTweak, pa.Action) assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak) @@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data) } @@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, []byte(`"something else"`), data) } diff --git a/pushrules/condition.go b/pushrules/condition.go index caa717de..435178fb 100644 --- a/pushrules/condition.go +++ b/pushrules/condition.go @@ -15,10 +15,10 @@ import ( "unicode" "github.com/tidwall/gjson" - "go.mau.fi/util/glob" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules/glob" ) // Room is an interface with the functions that are needed for processing room-specific push conditions @@ -27,11 +27,6 @@ type Room interface { GetMemberCount() int } -type PowerLevelfulRoom interface { - Room - GetPowerLevels() *event.PowerLevelsEventContent -} - // EventfulRoom is an extension of Room to support MSC3664. type EventfulRoom interface { Room @@ -43,12 +38,11 @@ type PushCondKind string // The allowed push condition kinds as specified in https://spec.matrix.org/v1.2/client-server-api/#conditions-1 const ( - KindEventMatch PushCondKind = "event_match" - KindContainsDisplayName PushCondKind = "contains_display_name" - KindRoomMemberCount PushCondKind = "room_member_count" - KindEventPropertyIs PushCondKind = "event_property_is" - KindEventPropertyContains PushCondKind = "event_property_contains" - KindSenderNotificationPermission PushCondKind = "sender_notification_permission" + KindEventMatch PushCondKind = "event_match" + KindContainsDisplayName PushCondKind = "contains_display_name" + KindRoomMemberCount PushCondKind = "room_member_count" + KindEventPropertyIs PushCondKind = "event_property_is" + KindEventPropertyContains PushCondKind = "event_property_contains" // MSC3664: https://github.com/matrix-org/matrix-spec-proposals/pull/3664 @@ -88,8 +82,6 @@ func (cond *PushCondition) Match(room Room, evt *event.Event) bool { return cond.matchDisplayName(room, evt) case KindRoomMemberCount: return cond.matchMemberCount(room) - case KindSenderNotificationPermission: - return cond.matchSenderNotificationPermission(room, evt.Sender, cond.Key) default: return false } @@ -227,11 +219,11 @@ func (cond *PushCondition) matchValue(evt *event.Event) bool { switch cond.Kind { case KindEventMatch, KindRelatedEventMatch, KindUnstableRelatedEventMatch: - pattern := glob.CompileWithImplicitContains(cond.Pattern) - if pattern == nil { + pattern, err := glob.Compile(cond.Pattern) + if err != nil { return false } - return pattern.Match(stringifyForPushCondition(val)) + return pattern.MatchString(stringifyForPushCondition(val)) case KindEventPropertyIs: return valueEquals(val, cond.Value) case KindEventPropertyContains: @@ -342,18 +334,3 @@ func (cond *PushCondition) matchMemberCount(room Room) bool { return false } } - -func (cond *PushCondition) matchSenderNotificationPermission(room Room, sender id.UserID, key string) bool { - if key != "room" { - return false - } - plRoom, ok := room.(PowerLevelfulRoom) - if !ok { - return false - } - pls := plRoom.GetPowerLevels() - if pls == nil { - return false - } - return pls.GetUserLevel(sender) >= pls.Notifications.Room() -} diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go index 37af3e34..0d3eaf7a 100644 --- a/pushrules/condition_test.go +++ b/pushrules/condition_test.go @@ -102,6 +102,14 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi } } +func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition { + return &pushrules.PushCondition{ + Kind: pushrules.KindEventPropertyContains, + Key: key, + Value: value, + } +} + func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), diff --git a/pushrules/glob/LICENSE b/pushrules/glob/LICENSE new file mode 100644 index 00000000..cb00d952 --- /dev/null +++ b/pushrules/glob/LICENSE @@ -0,0 +1,22 @@ +Glob is licensed under the MIT "Expat" License: + +Copyright (c) 2016: Zachary Yedidia. + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pushrules/glob/README.md b/pushrules/glob/README.md new file mode 100644 index 00000000..e2e6c649 --- /dev/null +++ b/pushrules/glob/README.md @@ -0,0 +1,28 @@ +# String globbing in Go + +[![GoDoc](https://godoc.org/github.com/zyedidia/glob?status.svg)](http://godoc.org/github.com/zyedidia/glob) + +This package adds support for globs in Go. + +It simply converts glob expressions to regexps. I try to follow the standard defined [here](http://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_13). + +# Example + +```go +package main + +import "github.com/zyedidia/glob" + +func main() { + glob, err := glob.Compile("{*.go,*.c}") + if err != nil { + // Error + } + + glob.Match([]byte("test.c")) // true + glob.Match([]byte("hello.go")) // true + glob.Match([]byte("test.d")) // false +} +``` + +You can call all the same functions on a glob that you can call on a regexp. diff --git a/pushrules/glob/glob.go b/pushrules/glob/glob.go new file mode 100644 index 00000000..c270dbc5 --- /dev/null +++ b/pushrules/glob/glob.go @@ -0,0 +1,108 @@ +// Package glob provides objects for matching strings with globs +package glob + +import "regexp" + +// Glob is a wrapper of *regexp.Regexp. +// It should contain a glob expression compiled into a regular expression. +type Glob struct { + *regexp.Regexp +} + +// Compile a takes a glob expression as a string and transforms it +// into a *Glob object (which is really just a regular expression) +// Compile also returns a possible error. +func Compile(pattern string) (*Glob, error) { + r, err := globToRegex(pattern) + return &Glob{r}, err +} + +func globToRegex(glob string) (*regexp.Regexp, error) { + regex := "" + inGroup := 0 + inClass := 0 + firstIndexInClass := -1 + arr := []byte(glob) + + hasGlobCharacters := false + + for i := 0; i < len(arr); i++ { + ch := arr[i] + + switch ch { + case '\\': + i++ + if i >= len(arr) { + regex += "\\" + } else { + next := arr[i] + switch next { + case ',': + // Nothing + case 'Q', 'E': + regex += "\\\\" + default: + regex += "\\" + } + regex += string(next) + } + case '*': + if inClass == 0 { + regex += ".*" + } else { + regex += "*" + } + hasGlobCharacters = true + case '?': + if inClass == 0 { + regex += "." + } else { + regex += "?" + } + hasGlobCharacters = true + case '[': + inClass++ + firstIndexInClass = i + 1 + regex += "[" + hasGlobCharacters = true + case ']': + inClass-- + regex += "]" + case '.', '(', ')', '+', '|', '^', '$', '@', '%': + if inClass == 0 || (firstIndexInClass == i && ch == '^') { + regex += "\\" + } + regex += string(ch) + hasGlobCharacters = true + case '!': + if firstIndexInClass == i { + regex += "^" + } else { + regex += "!" + } + hasGlobCharacters = true + case '{': + inGroup++ + regex += "(" + hasGlobCharacters = true + case '}': + inGroup-- + regex += ")" + case ',': + if inGroup > 0 { + regex += "|" + hasGlobCharacters = true + } else { + regex += "," + } + default: + regex += string(ch) + } + } + + if hasGlobCharacters { + return regexp.Compile("^" + regex + "$") + } else { + return regexp.Compile(regex) + } +} diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go index a5a0f5e7..a531ca28 100644 --- a/pushrules/pushrules_test.go +++ b/pushrules/pushrules_test.go @@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) { }, } pushRuleset, err := pushrules.EventToPushRules(evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.NotNil(t, pushRuleset) assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{}) diff --git a/pushrules/rule.go b/pushrules/rule.go index cf659695..0f7436f3 100644 --- a/pushrules/rule.go +++ b/pushrules/rule.go @@ -8,14 +8,10 @@ package pushrules import ( "encoding/gob" - "regexp" - "strings" - - "go.mau.fi/util/exerrors" - "go.mau.fi/util/glob" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules/glob" ) func init() { @@ -168,20 +164,13 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool { } func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool { + pattern, err := glob.Compile(rule.Pattern) + if err != nil { + return false + } msg, ok := evt.Content.Raw["body"].(string) if !ok { return false } - var buf strings.Builder - // As per https://spec.matrix.org/unstable/client-server-api/#push-rules, content rules are case-insensitive - // and must match whole words, so wrap the converted glob in (?i) and \b. - buf.WriteString(`(?i)\b`) - // strings.Builder will never return errors - exerrors.PanicIfNotNil(glob.ToRegexPattern(rule.Pattern, &buf)) - buf.WriteString(`\b`) - pattern, err := regexp.Compile(buf.String()) - if err != nil { - return false - } return pattern.MatchString(msg) } diff --git a/pushrules/rule_test.go b/pushrules/rule_test.go index 7ff839a7..803c721e 100644 --- a/pushrules/rule_test.go +++ b/pushrules/rule_test.go @@ -186,34 +186,6 @@ func TestPushRule_Match_Content(t *testing.T) { assert.True(t, rule.Match(blankTestRoom, evt)) } -func TestPushRule_Match_WordBoundary(t *testing.T) { - rule := &pushrules.PushRule{ - Type: pushrules.ContentRule, - Enabled: true, - Pattern: "test", - } - - evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ - MsgType: event.MsgEmote, - Body: "is testing pushrules", - }) - assert.False(t, rule.Match(blankTestRoom, evt)) -} - -func TestPushRule_Match_CaseInsensitive(t *testing.T) { - rule := &pushrules.PushRule{ - Type: pushrules.ContentRule, - Enabled: true, - Pattern: "test", - } - - evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ - MsgType: event.MsgEmote, - Body: "is TeSt-InG pushrules", - }) - assert.True(t, rule.Match(blankTestRoom, evt)) -} - func TestPushRule_Match_Content_Fail(t *testing.T) { rule := &pushrules.PushRule{ Type: pushrules.ContentRule, diff --git a/pushrules/ruleset.go b/pushrules/ruleset.go index c42d4799..609997b4 100644 --- a/pushrules/ruleset.go +++ b/pushrules/ruleset.go @@ -68,9 +68,6 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) { var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}} func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) { - if rs == nil { - return nil - } // Add push rule collections to array in priority order arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride} // Loop until one of the push rule collections matches the room/event combo. diff --git a/requests.go b/requests.go index cc8b7266..f91aaa79 100644 --- a/requests.go +++ b/requests.go @@ -2,9 +2,7 @@ package mautrix import ( "encoding/json" - "fmt" "strconv" - "time" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" @@ -40,40 +38,20 @@ const ( type Direction rune -func (d Direction) MarshalJSON() ([]byte, error) { - return json.Marshal(string(d)) -} - -func (d *Direction) UnmarshalJSON(data []byte) error { - var str string - if err := json.Unmarshal(data, &str); err != nil { - return err - } - switch str { - case "f": - *d = DirectionForward - case "b": - *d = DirectionBackward - default: - return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str) - } - return nil -} - const ( DirectionForward Direction = 'f' DirectionBackward Direction = 'b' ) // ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register -type ReqRegister[UIAType any] struct { +type ReqRegister struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` InhibitLogin bool `json:"inhibit_login,omitempty"` RefreshToken bool `json:"refresh_token,omitempty"` - Auth UIAType `json:"auth,omitempty"` + Auth interface{} `json:"auth,omitempty"` // Type for registration, only used for appservice user registrations // https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions @@ -105,7 +83,6 @@ type ReqLogin struct { Token string `json:"token,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` - RefreshToken bool `json:"refresh_token,omitempty"` // Whether or not the returned credentials should be stored in the Client StoreCredentials bool `json:"-"` @@ -113,10 +90,6 @@ type ReqLogin struct { StoreHomeserverURL bool `json:"-"` } -type ReqPutDevice struct { - DisplayName string `json:"display_name,omitempty"` -} - type ReqUIAuthFallback struct { Session string `json:"session"` User string `json:"user"` @@ -141,17 +114,14 @@ type ReqCreateRoom struct { InitialState []*event.Event `json:"initial_state,omitempty"` Preset string `json:"preset,omitempty"` IsDirect bool `json:"is_direct,omitempty"` - RoomVersion id.RoomVersion `json:"room_version,omitempty"` + RoomVersion string `json:"room_version,omitempty"` PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"` - MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"` BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"` BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"` - BeeperBridgeName string `json:"com.beeper.bridge_name,omitempty"` - BeeperBridgeAccountID string `json:"com.beeper.bridge_account_id,omitempty"` } // ReqRedact is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid @@ -161,37 +131,12 @@ type ReqRedact struct { Extra map[string]interface{} } -type ReqRedactUser struct { - Reason string `json:"reason"` - Limit int `json:"-"` -} - type ReqMembers struct { At string `json:"at"` Membership event.Membership `json:"membership,omitempty"` NotMembership event.Membership `json:"not_membership,omitempty"` } -type ReqJoinRoom struct { - Via []string `json:"-"` - Reason string `json:"reason,omitempty"` - ThirdPartySigned any `json:"third_party_signed,omitempty"` -} - -type ReqKnockRoom struct { - Via []string `json:"-"` - Reason string `json:"reason,omitempty"` -} - -type ReqSearchUserDirectory struct { - SearchTerm string `json:"search_term"` - Limit int `json:"limit,omitempty"` -} - -type ReqMutualRooms struct { - From string `json:"-"` -} - // ReqInvite3PID is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 // It is also a JSON object used in https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom type ReqInvite3PID struct { @@ -220,8 +165,6 @@ type ReqKickUser struct { type ReqBanUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` - - MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } // ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban @@ -237,8 +180,7 @@ type ReqTyping struct { } type ReqPresence struct { - Presence event.Presence `json:"presence"` - StatusMsg string `json:"status_msg,omitempty"` + Presence event.Presence `json:"presence"` } type ReqAliasCreate struct { @@ -283,7 +225,7 @@ func (otk *OneTimeKey) MarshalJSON() ([]byte, error) { type ReqUploadKeys struct { DeviceKeys *DeviceKeys `json:"device_keys,omitempty"` - OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys,omitempty"` + OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys"` } type ReqKeysSignatures struct { @@ -320,11 +262,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { return "" } -type UploadCrossSigningKeysReq[UIAType any] struct { +type UploadCrossSigningKeysReq struct { Master CrossSigningKeys `json:"master_key"` SelfSigning CrossSigningKeys `json:"self_signing_key"` UserSigning CrossSigningKeys `json:"user_signing_key"` - Auth UIAType `json:"auth,omitempty"` + Auth interface{} `json:"auth,omitempty"` } type KeyMap map[id.DeviceKeyID]string @@ -366,40 +308,20 @@ type ReqSendToDevice struct { Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"` } -type ReqSendEvent struct { - Timestamp int64 - TransactionID string - UnstableDelay time.Duration - UnstableStickyDuration time.Duration - DontEncrypt bool - MeowEventID id.EventID -} - -type ReqDelayedEvents struct { - DelayID id.DelayID `json:"-"` - Status event.DelayStatus `json:"-"` - NextBatch string `json:"-"` -} - -type ReqUpdateDelayedEvent struct { - DelayID id.DelayID `json:"-"` - Action event.DelayAction `json:"action"` -} - // ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid type ReqDeviceInfo struct { DisplayName string `json:"display_name,omitempty"` } // ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid -type ReqDeleteDevice[UIAType any] struct { - Auth UIAType `json:"auth,omitempty"` +type ReqDeleteDevice struct { + Auth interface{} `json:"auth,omitempty"` } // ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices -type ReqDeleteDevices[UIAType any] struct { +type ReqDeleteDevices struct { Devices []id.DeviceID `json:"devices"` - Auth UIAType `json:"auth,omitempty"` + Auth interface{} `json:"auth,omitempty"` } type ReqPutPushRule struct { @@ -411,6 +333,18 @@ type ReqPutPushRule struct { Pattern string `json:"pattern"` } +// Deprecated: MSC2716 was abandoned +type ReqBatchSend struct { + PrevEventID id.EventID `json:"-"` + BatchID id.BatchID `json:"-"` + + BeeperNewMessages bool `json:"-"` + BeeperMarkReadBy id.UserID `json:"-"` + + StateEventsAtStart []*event.Event `json:"state_events_at_start"` + Events []*event.Event `json:"events"` +} + type ReqBeeperBatchSend struct { // ForwardIfNoMessages should be set to true if the batch should be forward // backfilled if there are no messages currently in the room. @@ -446,33 +380,6 @@ type ReqSendReceipt struct { ThreadID string `json:"thread_id,omitempty"` } -type ReqPublicRooms struct { - IncludeAllNetworks bool - Limit int - Since string - ThirdPartyInstanceID string -} - -func (req *ReqPublicRooms) Query() map[string]string { - query := map[string]string{} - if req == nil { - return query - } - if req.IncludeAllNetworks { - query["include_all_networks"] = "true" - } - if req.Limit > 0 { - query["limit"] = strconv.Itoa(req.Limit) - } - if req.Since != "" { - query["since"] = req.Since - } - if req.ThirdPartyInstanceID != "" { - query["third_party_instance_id"] = req.ThirdPartyInstanceID - } - return query -} - // ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // As it's a GET method, there is no JSON body, so this is only query parameters. @@ -560,59 +467,3 @@ type ReqKeyBackupData struct { IsVerified bool `json:"is_verified"` SessionData json.RawMessage `json:"session_data"` } - -type ReqReport struct { - Reason string `json:"reason,omitempty"` - Score int `json:"score,omitempty"` -} - -type ReqGetRelations struct { - RelationType event.RelationType - EventType event.Type - - Dir Direction - From string - To string - Limit int - Recurse bool -} - -func (rgr *ReqGetRelations) PathSuffix() ClientURLPath { - if rgr.RelationType != "" { - if rgr.EventType.Type != "" { - return ClientURLPath{rgr.RelationType, rgr.EventType.Type} - } - return ClientURLPath{rgr.RelationType} - } - return ClientURLPath{} -} - -func (rgr *ReqGetRelations) Query() map[string]string { - query := map[string]string{} - if rgr.Dir != 0 { - query["dir"] = string(rgr.Dir) - } - if rgr.From != "" { - query["from"] = rgr.From - } - if rgr.To != "" { - query["to"] = rgr.To - } - if rgr.Limit > 0 { - query["limit"] = strconv.Itoa(rgr.Limit) - } - if rgr.Recurse { - query["recurse"] = "true" - } - return query -} - -// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type ReqSuspend struct { - Suspended bool `json:"suspended"` -} - -// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type ReqLocked struct { - Locked bool `json:"locked"` -} diff --git a/responses.go b/responses.go index 4fbe1fbc..9e5fd0aa 100644 --- a/responses.go +++ b/responses.go @@ -4,16 +4,13 @@ import ( "bytes" "encoding/json" "fmt" - "maps" "reflect" - "slices" "strconv" "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -35,11 +32,6 @@ type RespJoinRoom struct { RoomID id.RoomID `json:"room_id"` } -// RespKnockRoom is the JSON response for https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias -type RespKnockRoom struct { - RoomID id.RoomID `json:"room_id"` -} - // RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave type RespLeaveRoom struct{} @@ -105,29 +97,6 @@ type RespContext struct { // RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid type RespSendEvent struct { EventID id.EventID `json:"event_id"` - - UnstableDelayID id.DelayID `json:"delay_id,omitempty"` -} - -type RespUpdateDelayedEvent struct{} - -type RespDelayedEvents struct { - Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"` - Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"` - NextBatch string `json:"next_batch,omitempty"` - - // Deprecated: Synapse implementation still returns this - DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"` - // Deprecated: Synapse implementation still returns this - FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"` -} - -type RespRedactUserEvents struct { - IsMoreEvents bool `json:"is_more_events"` - RedactedEvents struct { - Total int `json:"total"` - SoftFailed int `json:"soft_failed"` - } `json:"redacted_events"` } // RespMediaConfig is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixmediav3config @@ -142,14 +111,10 @@ type RespMediaUpload struct { // RespCreateMXC is the JSON response for https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create type RespCreateMXC struct { - ContentURI id.ContentURI `json:"content_uri"` - UnusedExpiresAt jsontime.UnixMilli `json:"unused_expires_at,omitempty"` + ContentURI id.ContentURI `json:"content_uri"` + UnusedExpiresAt int `json:"unused_expires_at,omitempty"` UnstableUploadURL string `json:"com.beeper.msc3870.upload_url,omitempty"` - - // Beeper extensions for uploading unique media only once - BeeperUniqueID string `json:"com.beeper.unique_id,omitempty"` - BeeperCompletedAt jsontime.UnixMilli `json:"com.beeper.completed_at,omitempty"` } // RespPreviewURL is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url @@ -186,89 +151,8 @@ type RespUserDisplayName struct { } type RespUserProfile struct { - DisplayName string `json:"displayname,omitempty"` - AvatarURL id.ContentURI `json:"avatar_url,omitempty"` - Extra map[string]any `json:"-"` -} - -type marshalableUserProfile RespUserProfile - -func (r *RespUserProfile) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, &r.Extra) - if err != nil { - return err - } - r.DisplayName, _ = r.Extra["displayname"].(string) - avatarURL, _ := r.Extra["avatar_url"].(string) - if avatarURL != "" { - r.AvatarURL, _ = id.ParseContentURI(avatarURL) - } - delete(r.Extra, "displayname") - delete(r.Extra, "avatar_url") - return nil -} - -func (r *RespUserProfile) MarshalJSON() ([]byte, error) { - if len(r.Extra) == 0 { - return json.Marshal((*marshalableUserProfile)(r)) - } - marshalMap := maps.Clone(r.Extra) - if r.DisplayName != "" { - marshalMap["displayname"] = r.DisplayName - } else { - delete(marshalMap, "displayname") - } - if !r.AvatarURL.IsEmpty() { - marshalMap["avatar_url"] = r.AvatarURL.String() - } else { - delete(marshalMap, "avatar_url") - } - return json.Marshal(marshalMap) -} - -type RespSearchUserDirectory struct { - Limited bool `json:"limited"` - Results []*UserDirectoryEntry `json:"results"` -} - -type UserDirectoryEntry struct { - RespUserProfile - UserID id.UserID `json:"user_id"` -} - -func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error { - err := r.RespUserProfile.UnmarshalJSON(data) - if err != nil { - return err - } - userIDStr, _ := r.Extra["user_id"].(string) - r.UserID = id.UserID(userIDStr) - delete(r.Extra, "user_id") - return nil -} - -func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) { - if r.Extra == nil { - r.Extra = make(map[string]any) - } - r.Extra["user_id"] = r.UserID.String() - return r.RespUserProfile.MarshalJSON() -} - -type RespMutualRooms struct { - Joined []id.RoomID `json:"joined"` - NextBatch string `json:"next_batch,omitempty"` - Count int `json:"count,omitempty"` -} - -type RespRoomSummary struct { - PublicRoomInfo - - Membership event.Membership `json:"membership,omitempty"` - - UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` - UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"` - UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"` + DisplayName string `json:"displayname"` + AvatarURL id.ContentURI `json:"avatar_url"` } // RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable @@ -319,9 +203,6 @@ type RespLogin struct { DeviceID id.DeviceID `json:"device_id"` UserID id.UserID `json:"user_id"` WellKnown *ClientWellKnown `json:"well_known,omitempty"` - - RefreshToken string `json:"refresh_token,omitempty"` - ExpiresInMS int64 `json:"expires_in_ms,omitempty"` } // RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout @@ -342,24 +223,6 @@ type LazyLoadSummary struct { InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` } -func (lls *LazyLoadSummary) MemberCount() int { - if lls == nil { - return 0 - } - return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount) -} - -func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool { - if lls == other { - return true - } else if lls == nil || other == nil { - return false - } - return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) && - ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) && - slices.Equal(lls.Heroes, other.Heroes) -} - type SyncEventsList struct { Events []*event.Event `json:"events,omitempty"` } @@ -455,7 +318,6 @@ type BeeperInboxPreviewEvent struct { type SyncJoinedRoom struct { Summary LazyLoadSummary `json:"summary"` State SyncEventsList `json:"state"` - StateAfter *SyncEventsList `json:"state_after,omitempty"` Timeline SyncTimeline `json:"timeline"` Ephemeral SyncEventsList `json:"ephemeral"` AccountData SyncEventsList `json:"account_data"` @@ -481,7 +343,16 @@ func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) { } type SyncInvitedRoom struct { - State SyncEventsList `json:"invite_state"` + Summary LazyLoadSummary `json:"summary"` + State SyncEventsList `json:"invite_state"` +} + +type marshalableSyncInvitedRoom SyncInvitedRoom + +var syncInvitedRoomPathsToDelete = []string{"summary"} + +func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) { + return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete) } type SyncKnockedRoom struct { @@ -546,19 +417,29 @@ type RespDeviceInfo struct { LastSeenTS int64 `json:"last_seen_ts"` } +// Deprecated: MSC2716 was abandoned +type RespBatchSend struct { + StateEventIDs []id.EventID `json:"state_event_ids"` + EventIDs []id.EventID `json:"event_ids"` + + InsertionEventID id.EventID `json:"insertion_event_id"` + BatchEventID id.EventID `json:"batch_event_id"` + BaseInsertionEventID id.EventID `json:"base_insertion_event_id"` + + NextBatchID id.BatchID `json:"next_batch_id"` +} + type RespBeeperBatchSend struct { EventIDs []id.EventID `json:"event_ids"` } // RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities type RespCapabilities struct { - RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` - ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` - SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` - SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` - ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` - GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` - UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"` + RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` + ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` + SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` + SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` + ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` Custom map[string]interface{} `json:"-"` } @@ -667,44 +548,29 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool { return available } -type CapUnstableAccountModeration struct { - Suspend bool `json:"suspend"` - Lock bool `json:"lock"` -} - -type RespPublicRooms struct { - Chunk []*PublicRoomInfo `json:"chunk"` - NextBatch string `json:"next_batch,omitempty"` - PrevBatch string `json:"prev_batch,omitempty"` - TotalRoomCountEstimate int `json:"total_room_count_estimate"` -} - -type PublicRoomInfo struct { - RoomID id.RoomID `json:"room_id"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` - GuestCanJoin bool `json:"guest_can_join"` - JoinRule event.JoinRule `json:"join_rule,omitempty"` - Name string `json:"name,omitempty"` - NumJoinedMembers int `json:"num_joined_members"` - RoomType event.RoomType `json:"room_type"` - Topic string `json:"topic,omitempty"` - WorldReadable bool `json:"world_readable"` - - RoomVersion id.RoomVersion `json:"room_version,omitempty"` - Encryption id.Algorithm `json:"encryption,omitempty"` - AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` -} - // RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy type RespHierarchy struct { - NextBatch string `json:"next_batch,omitempty"` - Rooms []*ChildRoomsChunk `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` + Rooms []ChildRoomsChunk `json:"rooms"` } type ChildRoomsChunk struct { - PublicRoomInfo - ChildrenState []*event.Event `json:"children_state"` + AvatarURL id.ContentURI `json:"avatar_url,omitempty"` + CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` + ChildrenState []StrippedStateWithTime `json:"children_state"` + GuestCanJoin bool `json:"guest_can_join"` + JoinRule event.JoinRule `json:"join_rule,omitempty"` + Name string `json:"name,omitempty"` + NumJoinedMembers int `json:"num_joined_members"` + RoomID id.RoomID `json:"room_id"` + RoomType event.RoomType `json:"room_type"` + Topic string `json:"topic,omitempty"` + WorldReadble bool `json:"world_readable"` +} + +type StrippedStateWithTime struct { + event.StrippedState + Timestamp jsontime.UnixMilli `json:"origin_server_ts"` } type RespAppservicePing struct { @@ -753,47 +619,3 @@ type RespRoomKeysUpdate struct { Count int `json:"count"` ETag string `json:"etag"` } - -type RespOpenIDToken struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - MatrixServerName string `json:"matrix_server_name"` - TokenType string `json:"token_type"` // Always "Bearer" -} - -type RespGetRelations struct { - Chunk []*event.Event `json:"chunk"` - NextBatch string `json:"next_batch,omitempty"` - PrevBatch string `json:"prev_batch,omitempty"` - RecursionDepth int `json:"recursion_depth,omitempty"` -} - -// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type RespSuspended struct { - Suspended bool `json:"suspended"` -} - -// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type RespLocked struct { - Locked bool `json:"locked"` -} - -type ConnectionInfo struct { - IP string `json:"ip,omitempty"` - LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"` - UserAgent string `json:"user_agent,omitempty"` -} - -type SessionInfo struct { - Connections []ConnectionInfo `json:"connections,omitempty"` -} - -type DeviceInfo struct { - Sessions []SessionInfo `json:"sessions,omitempty"` -} - -// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid -type RespWhoIs struct { - UserID id.UserID `json:"user_id,omitempty"` - Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"` -} diff --git a/responses_test.go b/responses_test.go index 73d82635..b23d85ad 100644 --- a/responses_test.go +++ b/responses_test.go @@ -8,6 +8,7 @@ package mautrix_test import ( "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -85,6 +86,7 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) { var caps mautrix.RespCapabilities err := json.Unmarshal([]byte(sampleData), &caps) require.NoError(t, err) + fmt.Println(caps) require.NotNil(t, caps.RoomVersions) assert.Equal(t, "9", caps.RoomVersions.Default) diff --git a/room.go b/room.go index 4292bff5..c3ddb7e6 100644 --- a/room.go +++ b/room.go @@ -5,6 +5,8 @@ import ( "maunium.net/go/mautrix/id" ) +type RoomStateMap = map[event.Type]map[string]*event.Event + // Room represents a single Matrix room. type Room struct { ID id.RoomID @@ -23,8 +25,8 @@ func (room Room) UpdateState(evt *event.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { - stateEventMap := room.State[eventType] - evt := stateEventMap[stateKey] + stateEventMap, _ := room.State[eventType] + evt, _ := stateEventMap[stateKey] return evt } diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 11957dfa..0e5c4184 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -19,7 +19,6 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/confusable" "go.mau.fi/util/dbutil" - "go.mau.fi/util/exslices" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -62,9 +61,6 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) } func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error { - if userID == "" { - return fmt.Errorf("user ID is empty") - } _, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) return err } @@ -88,11 +84,14 @@ func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ",")) } rows, err := store.Query(ctx, query, args...) + if err != nil { + return nil, err + } members := make(map[id.UserID]*event.MemberEventContent) - return members, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (ret Member, err error) { + return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) { err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL) return - }, err).Iter(func(m Member) (bool, error) { + }).Iter(func(m Member) (bool, error) { members[m.UserID] = &m.MemberEventContent return true, nil }) @@ -159,7 +158,10 @@ func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserI ` } rows, err := store.Query(ctx, query, userID) - return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() + if err != nil { + return nil, err + } + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() } func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { @@ -185,11 +187,6 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, } func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } else if userID == "" { - return fmt.Errorf("user ID is empty") - } _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '') ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership @@ -197,42 +194,21 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, return err } -const insertUserProfileQuery = ` - INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (room_id, user_id) DO UPDATE - SET membership=excluded.membership, - displayname=excluded.displayname, - avatar_url=excluded.avatar_url, - name_skeleton=excluded.name_skeleton -` - -type userProfileRow struct { - UserID id.UserID - Membership event.Membership - Displayname string - AvatarURL id.ContentURIString - NameSkeleton []byte -} - -func (u *userProfileRow) GetMassInsertValues() [5]any { - return [5]any{u.UserID, u.Membership, u.Displayname, u.AvatarURL, u.NameSkeleton} -} - -var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)") - func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } else if userID == "" { - return fmt.Errorf("user ID is empty") - } var nameSkeleton []byte if !store.DisableNameDisambiguation && len(member.Displayname) > 0 { nameSkeletonArr := confusable.SkeletonHash(member.Displayname) nameSkeleton = nameSkeletonArr[:] } - _, err := store.Exec(ctx, insertUserProfileQuery, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) + _, err := store.Exec(ctx, ` + INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id) DO UPDATE + SET membership=excluded.membership, + displayname=excluded.displayname, + avatar_url=excluded.avatar_url, + name_skeleton=excluded.name_skeleton + `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) return err } @@ -245,53 +221,6 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } -const userProfileMassInsertBatchSize = 500 - -func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } - return store.DoTxn(ctx, nil, func(ctx context.Context) error { - err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...) - if err != nil { - return fmt.Errorf("failed to clear cached members: %w", err) - } - rows := make([]*userProfileRow, min(len(evts), userProfileMassInsertBatchSize)) - for _, evtsChunk := range exslices.Chunk(evts, userProfileMassInsertBatchSize) { - rows = rows[:0] - for _, evt := range evtsChunk { - content, ok := evt.Content.Parsed.(*event.MemberEventContent) - if !ok { - continue - } - row := &userProfileRow{ - UserID: id.UserID(*evt.StateKey), - Membership: content.Membership, - Displayname: content.Displayname, - AvatarURL: content.AvatarURL, - } - if !store.DisableNameDisambiguation && len(content.Displayname) > 0 { - nameSkeletonArr := confusable.SkeletonHash(content.Displayname) - row.NameSkeleton = nameSkeletonArr[:] - } - rows = append(rows, row) - } - query, args := userProfileMassInserter.Build([1]any{roomID}, rows) - _, err = store.Exec(ctx, query, args...) - if err != nil { - return fmt.Errorf("failed to insert members: %w", err) - } - } - if len(onlyMemberships) == 0 { - err = store.MarkMembersFetched(ctx, roomID) - if err != nil { - return fmt.Errorf("failed to mark members as fetched: %w", err) - } - } - return nil - }) -} - func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { query := "DELETE FROM mx_user_profile WHERE room_id=$1" params := make([]any, len(memberships)+1) @@ -305,57 +234,10 @@ func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.Ro query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ",")) } _, err := store.Exec(ctx, query, params...) - if err != nil { - return err - } - _, err = store.Exec(ctx, "UPDATE mx_room_state SET members_fetched=false WHERE room_id=$1", roomID) return err } -func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (fetched bool, err error) { - err = store.QueryRow(ctx, "SELECT COALESCE(members_fetched, false) FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return -} - -func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } - _, err := store.Exec(ctx, ` - INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true) - ON CONFLICT (room_id) DO UPDATE SET members_fetched=true - `, roomID) - return err -} - -type userAndMembership struct { - UserID id.UserID - event.MemberEventContent -} - -func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - rows, err := store.Query(ctx, "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID) - if err != nil { - return nil, err - } - output := make(map[id.UserID]*event.MemberEventContent) - err = dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (res userAndMembership, err error) { - err = row.Scan(&res.UserID, &res.Membership, &res.Displayname, &res.AvatarURL) - return - }, err).Iter(func(member userAndMembership) (bool, error) { - output[member.UserID] = &member.MemberEventContent - return true, nil - }) - return output, err -} - func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } contentBytes, err := json.Marshal(content) if err != nil { return fmt.Errorf("failed to marshal content JSON: %w", err) @@ -370,7 +252,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { var data []byte err := store. - QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID). + QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). Scan(&data) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -393,9 +275,6 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) ( } func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels @@ -404,92 +283,89 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID } func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { - levels = &event.PowerLevelsEventContent{} err = store. - QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID). - Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent}) + QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). + Scan(&dbutil.JSON{Data: &levels}) if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } else if err != nil { - return nil, err - } - if levels.CreateEvent != nil { - err = levels.CreateEvent.Content.ParseRaw(event.StateCreate) + err = nil } return } func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err + if store.Dialect == dbutil.Postgres { + var powerLevel int + err := store. + QueryRow(ctx, ` + SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) + FROM mx_room_state WHERE room_id=$1 + `, roomID, userID). + Scan(&powerLevel) + return powerLevel, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err + } + return levels.GetUserLevel(userID), nil } - return levels.GetUserLevel(userID), nil } func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err + if store.Dialect == dbutil.Postgres { + defaultType := "events_default" + defaultValue := 0 + if eventType.IsState() { + defaultType = "state_default" + defaultValue = 50 + } + var powerLevel int + err := store. + QueryRow(ctx, ` + SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4) + FROM mx_room_state WHERE room_id=$1 + `, roomID, eventType.Type, defaultType, defaultValue). + Scan(&powerLevel) + if errors.Is(err, sql.ErrNoRows) { + err = nil + powerLevel = defaultValue + } + return powerLevel, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err + } + return levels.GetEventLevel(eventType), nil } - return levels.GetEventLevel(eventType), nil } func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return false, err + if store.Dialect == dbutil.Postgres { + defaultType := "events_default" + defaultValue := 0 + if eventType.IsState() { + defaultType = "state_default" + defaultValue = 50 + } + var hasPower bool + err := store. + QueryRow(ctx, `SELECT + COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) + >= + COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5) + FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue). + Scan(&hasPower) + if errors.Is(err, sql.ErrNoRows) { + err = nil + hasPower = defaultValue == 0 + } + return hasPower, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return false, err + } + return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil } - return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil -} - -func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error { - if evt.Type != event.StateCreate { - return fmt.Errorf("invalid event type for create event: %s", evt.Type) - } else if evt.RoomID == "" { - return fmt.Errorf("room ID is empty") - } - _, err := store.Exec(ctx, ` - INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) - ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event - `, evt.RoomID, dbutil.JSON{Data: evt}) - return err -} - -func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) { - err = store. - QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID). - Scan(&dbutil.JSON{Data: &evt}) - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } else if err != nil { - return nil, err - } - if evt != nil { - err = evt.Content.ParseRaw(event.StateCreate) - } - return -} - -func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } - _, err := store.Exec(ctx, ` - INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2) - ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules - `, roomID, dbutil.JSON{Data: rules}) - return err -} - -func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) { - levels = &event.JoinRulesEventContent{} - err = store. - QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID). - Scan(&dbutil.JSON{Data: &levels}) - if errors.Is(err, sql.ErrNoRows) { - levels = nil - err = nil - } - return } diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index 4679f1c6..b2bb2ae6 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v10 (compatible with v3+): Latest revision +-- v0 -> v6 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -8,11 +8,11 @@ CREATE TABLE mx_registrations ( CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock'); CREATE TABLE mx_user_profile ( - room_id TEXT, - user_id TEXT, - membership membership NOT NULL, - displayname TEXT NOT NULL DEFAULT '', - avatar_url TEXT NOT NULL DEFAULT '', + room_id TEXT, + user_id TEXT, + membership membership NOT NULL, + displayname TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', name_skeleton bytea, @@ -23,10 +23,7 @@ CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, members CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton); CREATE TABLE mx_room_state ( - room_id TEXT PRIMARY KEY, - power_levels jsonb, - encryption jsonb, - create_event jsonb, - join_rules jsonb, - members_fetched BOOLEAN NOT NULL DEFAULT false + room_id TEXT PRIMARY KEY, + power_levels jsonb, + encryption jsonb ); diff --git a/sqlstatestore/v07-full-member-flag.sql b/sqlstatestore/v07-full-member-flag.sql deleted file mode 100644 index 32f2ef6c..00000000 --- a/sqlstatestore/v07-full-member-flag.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v7 (compatible with v3+): Add flag for whether the full member list has been fetched -ALTER TABLE mx_room_state ADD COLUMN members_fetched BOOLEAN NOT NULL DEFAULT false; diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql deleted file mode 100644 index 9f1b55c9..00000000 --- a/sqlstatestore/v08-create-event.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v8 (compatible with v3+): Add create event to room state table -ALTER TABLE mx_room_state ADD COLUMN create_event jsonb; diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql deleted file mode 100644 index ca951068..00000000 --- a/sqlstatestore/v09-clear-empty-room-ids.sql +++ /dev/null @@ -1,3 +0,0 @@ --- v9 (compatible with v3+): Clear invalid rows -DELETE FROM mx_room_state WHERE room_id=''; -DELETE FROM mx_user_profile WHERE room_id='' OR user_id=''; diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql deleted file mode 100644 index 3074c46a..00000000 --- a/sqlstatestore/v10-join-rules.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v10 (compatible with v3+): Add join rules to room state table -ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb; diff --git a/statestore.go b/statestore.go index 2bd498dd..fd8f81e5 100644 --- a/statestore.go +++ b/statestore.go @@ -8,7 +8,6 @@ package mautrix import ( "context" - "maps" "sync" "github.com/rs/zerolog" @@ -29,21 +28,10 @@ type StateStore interface { SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error - ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) - SetCreate(ctx context.Context, evt *event.Event) error - GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) - - GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) - SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error - - HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) - MarkMembersFetched(ctx context.Context, roomID id.RoomID) error - GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) - SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) @@ -74,19 +62,6 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetPowerLevels(ctx, evt.RoomID, content) case *event.EncryptionEventContent: err = store.SetEncryptionEvent(ctx, evt.RoomID, content) - case *event.CreateEventContent: - err = store.SetCreate(ctx, evt) - case *event.JoinRulesEventContent: - err = store.SetJoinRules(ctx, evt.RoomID, content) - default: - switch evt.Type { - case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate: - zerolog.Ctx(ctx).Warn(). - Stringer("event_id", evt.ID). - Str("event_type", evt.Type.Type). - Type("content_type", evt.Content.Parsed). - Msg("Got known event type with unknown content type in UpdateStateStore") - } } if err != nil { zerolog.Ctx(ctx).Warn().Err(err). @@ -106,30 +81,23 @@ func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) } type MemoryStateStore struct { - Registrations map[id.UserID]bool `json:"registrations"` - Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` - MembersFetched map[id.RoomID]bool `json:"members_fetched"` - PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` - Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` - Create map[id.RoomID]*event.Event `json:"create"` - JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"` + Registrations map[id.UserID]bool `json:"registrations"` + Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` + PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` + Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` registrationsLock sync.RWMutex membersLock sync.RWMutex powerLevelsLock sync.RWMutex encryptionLock sync.RWMutex - joinRulesLock sync.RWMutex } func NewMemoryStateStore() StateStore { return &MemoryStateStore{ - Registrations: make(map[id.UserID]bool), - Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), - MembersFetched: make(map[id.RoomID]bool), - PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), - Encryption: make(map[id.RoomID]*event.EncryptionEventContent), - Create: make(map[id.RoomID]*event.Event), - JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent), + Registrations: make(map[id.UserID]bool), + Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), + PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), + Encryption: make(map[id.RoomID]*event.EncryptionEventContent), } } @@ -269,40 +237,9 @@ func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.R } } } - store.MembersFetched[roomID] = false return nil } -func (store *MemoryStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { - store.membersLock.RLock() - defer store.membersLock.RUnlock() - return store.MembersFetched[roomID], nil -} - -func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { - store.membersLock.Lock() - defer store.membersLock.Unlock() - store.MembersFetched[roomID] = true - return nil -} - -func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { - _ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...) - for _, evt := range evts { - UpdateStateStore(ctx, store, evt) - } - if len(onlyMemberships) == 0 { - _ = store.MarkMembersFetched(ctx, roomID) - } - return nil -} - -func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - store.membersLock.RLock() - defer store.membersLock.RUnlock() - return maps.Clone(store.Members[roomID]), nil -} - func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { store.powerLevelsLock.Lock() store.PowerLevels[roomID] = levels @@ -313,9 +250,6 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { store.powerLevelsLock.RLock() levels = store.PowerLevels[roomID] - if levels != nil && levels.CreateEvent == nil { - levels.CreateEvent = store.Create[roomID] - } store.powerLevelsLock.RUnlock() return } @@ -332,23 +266,6 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil } -func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error { - store.powerLevelsLock.Lock() - store.Create[evt.RoomID] = evt - if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil { - pls.CreateEvent = evt - } - store.powerLevelsLock.Unlock() - return nil -} - -func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) { - store.powerLevelsLock.RLock() - evt := store.Create[roomID] - store.powerLevelsLock.RUnlock() - return evt, nil -} - func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { store.encryptionLock.Lock() store.Encryption[roomID] = content @@ -362,19 +279,6 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R return store.Encryption[roomID], nil } -func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error { - store.joinRulesLock.Lock() - store.JoinRules[roomID] = content - store.joinRulesLock.Unlock() - return nil -} - -func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) { - store.joinRulesLock.RLock() - defer store.joinRulesLock.RUnlock() - return store.JoinRules[roomID], nil -} - func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { cfg, err := store.GetEncryptionEvent(ctx, roomID) return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err diff --git a/synapseadmin/client.go b/synapseadmin/client.go index 6925ca7d..775b4b13 100644 --- a/synapseadmin/client.go +++ b/synapseadmin/client.go @@ -14,9 +14,9 @@ import ( // // https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html type Client struct { - Client *mautrix.Client + *mautrix.Client } func (cli *Client) BuildAdminURL(path ...any) string { - return cli.Client.BuildURL(mautrix.SynapseAdminURLPath(path)) + return cli.BuildURL(mautrix.SynapseAdminURLPath(path)) } diff --git a/synapseadmin/register.go b/synapseadmin/register.go index 05e0729a..641f9b56 100644 --- a/synapseadmin/register.go +++ b/synapseadmin/register.go @@ -73,7 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string { // This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided. func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) { var resp respGetRegisterNonce - _, err := cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "register"), nil, &resp) + _, err := cli.MakeRequest(ctx, http.MethodGet, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), nil, &resp) if err != nil { return "", err } @@ -93,7 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string } req.SHA1Checksum = req.Sign(sharedSecret) var resp mautrix.RespRegister - _, err = cli.Client.MakeRequest(ctx, http.MethodPost, cli.BuildAdminURL("v1", "register"), &req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), &req, &resp) if err != nil { return nil, err } diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index 0925b748..6c072e23 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -75,17 +75,12 @@ type RespListRooms struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms - reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) - _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + var reqURL string + reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } -func (cli *Client) RoomInfo(ctx context.Context, roomID id.RoomID) (resp *RoomInfo, err error) { - reqURL := cli.BuildAdminURL("v1", "rooms", roomID) - _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) - return -} - type RespRoomMessages = mautrix.RespMessages // RoomMessages returns a list of messages in a room. @@ -109,14 +104,13 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to if limit != 0 { query["limit"] = strconv.Itoa(limit) } - urlPath := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) - _, err = cli.Client.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return resp, err } type ReqDeleteRoom struct { Purge bool `json:"purge,omitempty"` - ForcePurge bool `json:"force_purge,omitempty"` Block bool `json:"block,omitempty"` Message string `json:"message,omitempty"` RoomName string `json:"room_name,omitempty"` @@ -127,19 +121,6 @@ type RespDeleteRoom struct { DeleteID string `json:"delete_id"` } -type RespDeleteRoomResult struct { - KickedUsers []id.UserID `json:"kicked_users,omitempty"` - FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"` - LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"` - NewRoomID id.RoomID `json:"new_room_id,omitempty"` -} - -type RespDeleteRoomStatus struct { - Status string `json:"status,omitempty"` - Error string `json:"error,omitempty"` - ShutdownRoom RespDeleteRoomResult `json:"shutdown_room,omitempty"` -} - // DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database. // // This calls the async version of the endpoint, which will return immediately and delete the room in the background. @@ -148,37 +129,10 @@ type RespDeleteRoomStatus struct { func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) { reqURL := cli.BuildAdminURL("v2", "rooms", roomID) var resp RespDeleteRoom - _, err := cli.Client.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) + _, err := cli.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) return resp, err } -func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp RespDeleteRoomStatus, err error) { - reqURL := cli.BuildAdminURL("v2", "rooms", "delete_status", deleteID) - _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) - return -} - -// DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database. -// -// This calls the synchronous version of the endpoint, which will block until the room is deleted. -// -// https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version -func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomResult, err error) { - reqURL := cli.BuildAdminURL("v1", "rooms", roomID) - httpClient := &http.Client{} - _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodDelete, - URL: reqURL, - RequestJSON: &req, - ResponseJSON: &resp, - MaxAttempts: 1, - // Use a fresh HTTP client without timeouts - Client: httpClient, - }) - httpClient.CloseIdleConnections() - return -} - type RespRoomsMembers struct { Members []id.UserID `json:"members"` Total int `json:"total"` @@ -190,7 +144,7 @@ type RespRoomsMembers struct { func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members") var resp RespRoomsMembers - _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -203,7 +157,7 @@ type ReqMakeRoomAdmin struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin") - _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -216,7 +170,7 @@ type ReqJoinUserToRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error { reqURL := cli.BuildAdminURL("v1", "join", roomID) - _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -229,7 +183,7 @@ type ReqBlockRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -245,6 +199,6 @@ type RoomsBlockResponse struct { func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) { var resp RoomsBlockResponse reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index b1de55b6..9cbb17e4 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -32,7 +32,7 @@ type ReqResetPassword struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error { reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID) - _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -43,8 +43,8 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) { - u := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) - _, err = cli.Client.MakeRequest(ctx, http.MethodGet, u, nil, &resp) + u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) } @@ -65,7 +65,7 @@ type RespListDevices struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) { - _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) return } @@ -89,7 +89,7 @@ type RespUserInfo struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) { - _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) return } @@ -102,20 +102,7 @@ type ReqDeleteUser struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error { reqURL := cli.BuildAdminURL("v1", "deactivate", userID) - _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) - return err -} - -type ReqSuspendUser struct { - Suspend bool `json:"suspend"` -} - -// SuspendAccount suspends or unsuspends a specific local user account. -// -// https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account -func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error { - reqURL := cli.BuildAdminURL("v1", "suspend", userID) - _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -137,7 +124,7 @@ type ReqCreateOrModifyAccount struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error { reqURL := cli.BuildAdminURL("v2", "users", userID) - _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -153,7 +140,7 @@ type ReqSetRatelimit = RatelimitOverride // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error { reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit") - _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -163,7 +150,7 @@ type RespUserRatelimit = RatelimitOverride // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) { - _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) return } @@ -171,6 +158,6 @@ func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) { - _, err = cli.Client.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) + _, err = cli.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) return } diff --git a/sync.go b/sync.go index 598df8e0..d4208404 100644 --- a/sync.go +++ b/sync.go @@ -90,7 +90,6 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) } }() - ctx = context.WithValue(ctx, SyncTokenContextKey, since) for _, listener := range s.syncListeners { if !listener(ctx, res, since) { @@ -98,38 +97,33 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc } } - s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice, false) - s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence, false) - s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData, false) + s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice) + s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence) + s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData) for roomID, roomData := range res.Rooms.Join { - if roomData.StateAfter == nil { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState, false) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, false) - } else { - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, true) - s.processSyncEvents(ctx, roomID, roomData.StateAfter.Events, event.SourceJoin|event.SourceState, false) - } - s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral, false) - s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData, false) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral) + s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData) } for roomID, roomData := range res.Rooms.Invite { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState, false) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState) } for roomID, roomData := range res.Rooms.Leave { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState, false) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline, false) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline) } return } -func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source, ignoreState bool) { +func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) { for _, evt := range events { - s.processSyncEvent(ctx, roomID, evt, source, ignoreState) + s.processSyncEvent(ctx, roomID, evt, source) } } -func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source, ignoreState bool) { +func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) { evt.RoomID = roomID // Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer. @@ -155,7 +149,6 @@ func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, } evt.Mautrix.EventSource = source - evt.Mautrix.IgnoreState = ignoreState s.Dispatch(ctx, evt) } @@ -198,8 +191,8 @@ func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, e } var defaultFilter = Filter{ - Room: &RoomFilter{ - Timeline: &FilterPart{ + Room: RoomFilter{ + Timeline: FilterPart{ Limit: 50, }, }, @@ -264,7 +257,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState) func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool { for _, meta := range resp.Rooms.Invite { - var inviteState []*event.Event + var inviteState []event.StrippedState var inviteEvt *event.Event for _, evt := range meta.State.Events { if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() { @@ -272,7 +265,12 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string } else { evt.Type.Class = event.StateEventType _ = evt.Content.ParseRaw(evt.Type) - inviteState = append(inviteState, evt) + inviteState = append(inviteState, event.StrippedState{ + Content: evt.Content, + Type: evt.Type, + StateKey: evt.GetStateKey(), + Sender: evt.Sender, + }) } } if inviteEvt != nil { diff --git a/url.go b/url.go index 91b3d49d..4646b442 100644 --- a/url.go +++ b/url.go @@ -57,13 +57,13 @@ func BuildURL(baseURL *url.URL, path ...any) *url.URL { // BuildURL builds a URL with the Client's homeserver and appservice user ID set already. func (cli *Client) BuildURL(urlPath PrefixableURLPath) string { - return cli.BuildURLWithFullQuery(urlPath, nil) + return cli.BuildURLWithQuery(urlPath, nil) } // BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already. // This method also automatically prepends the client API prefix (/_matrix/client). func (cli *Client) BuildClientURL(urlPath ...any) string { - return cli.BuildURLWithFullQuery(ClientURLPath(urlPath), nil) + return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil) } type PrefixableURLPath interface { @@ -97,30 +97,15 @@ func (saup SynapseAdminURLPath) FullPath() []any { // BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { - return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { - for k, v := range urlQuery { - q.Set(k, v) - } - }) -} - -// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver -// and appservice user ID set already. -func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string { - if cli == nil { - return "client is nil" - } hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...) query := hsURL.Query() if cli.SetAppServiceUserID { query.Set("user_id", string(cli.UserID)) } - if cli.SetAppServiceDeviceID && cli.DeviceID != "" { - query.Set("device_id", string(cli.DeviceID)) - query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID)) - } - if fn != nil { - fn(query) + if urlQuery != nil { + for k, v := range urlQuery { + query.Set(k, v) + } } hsURL.RawQuery = query.Encode() return hsURL.String() diff --git a/version.go b/version.go index f00bbf39..29c5eb46 100644 --- a/version.go +++ b/version.go @@ -4,11 +4,10 @@ import ( "fmt" "regexp" "runtime" - "runtime/debug" "strings" ) -const Version = "v0.26.3" +const Version = "v0.20.0" var GoModVersion = "" var Commit = "" @@ -16,20 +15,11 @@ var VersionWithCommit = Version var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go") +var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`) + func init() { - if GoModVersion == "" { - info, _ := debug.ReadBuildInfo() - if info != nil { - for _, mod := range info.Deps { - if mod.Path == "maunium.net/go/mautrix" { - GoModVersion = mod.Version - break - } - } - } - } if GoModVersion != "" { - match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion) + match := goModVersionRegex.FindStringSubmatch(GoModVersion) if match != nil { Commit = match[1] } diff --git a/versions.go b/versions.go index 61b2e4ea..60eb0f30 100644 --- a/versions.go +++ b/versions.go @@ -60,28 +60,16 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} - FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/} - FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} - FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} - FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} - FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/} - FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} - FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} - FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} - BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} - BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} - BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} - BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} - BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} - BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} - BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"} - BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"} + BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} + BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} + BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} + BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} + BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} + BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { @@ -121,12 +109,6 @@ var ( SpecV19 = MustParseSpecVersion("v1.9") SpecV110 = MustParseSpecVersion("v1.10") SpecV111 = MustParseSpecVersion("v1.11") - SpecV112 = MustParseSpecVersion("v1.12") - SpecV113 = MustParseSpecVersion("v1.13") - SpecV114 = MustParseSpecVersion("v1.14") - SpecV115 = MustParseSpecVersion("v1.15") - SpecV116 = MustParseSpecVersion("v1.16") - SpecV117 = MustParseSpecVersion("v1.17") ) func (svf SpecVersionFormat) String() string {