diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 71c1988b..c0add220 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -10,12 +10,12 @@ jobs:
runs-on: ubuntu-latest
name: Lint (latest)
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
- go-version: "1.24"
+ go-version: "1.26"
cache: true
- name: Install libolm
@@ -24,6 +24,7 @@ jobs:
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
+ go install honnef.co/go/tools/cmd/staticcheck@latest
export PATH="$HOME/go/bin:$PATH"
- name: Run pre-commit
@@ -34,14 +35,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- go-version: ["1.23", "1.24"]
- name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, libolm)
+ go-version: ["1.25", "1.26"]
+ name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm)
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go ${{ matrix.go-version }}
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}
cache: true
@@ -60,28 +61,28 @@ jobs:
- name: Test
run: go test -json -v ./... 2>&1 | gotestfmt
+ - name: Test (jsonv2)
+ env:
+ GOEXPERIMENT: jsonv2
+ run: go test -json -v ./... 2>&1 | gotestfmt
+
build-goolm:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- go-version: ["1.23", "1.24"]
- name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, goolm)
+ go-version: ["1.25", "1.26"]
+ name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm)
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go ${{ matrix.go-version }}
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}
cache: true
- - name: Set up gotestfmt
- uses: GoTestTools/gotestfmt-action@v2
- with:
- token: ${{ secrets.GITHUB_TOKEN }}
-
- name: Build
run: |
rm -rf crypto/libolm
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 578349c9..9a9e7375 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -17,7 +17,7 @@ jobs:
lock-stale:
runs-on: ubuntu-latest
steps:
- - uses: dessant/lock-threads@v5
+ - uses: dessant/lock-threads@v6
id: lock
with:
issue-inactive-days: 90
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 81701203..616fccb2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v5.0.0
+ rev: v6.0.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
@@ -9,7 +9,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/tekwizely/pre-commit-golang
- rev: v1.0.0-rc.1
+ rev: v1.0.0-rc.4
hooks:
- id: go-imports-repo
args:
@@ -18,8 +18,7 @@ repos:
- "-w"
- id: go-vet-repo-mod
- id: go-mod-tidy
- # TODO enable this
- #- id: go-staticcheck-repo-mod
+ - id: go-staticcheck-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.4.2
@@ -27,3 +26,4 @@ repos:
- id: prevent-literal-http-methods
- id: zerolog-ban-global-log
- id: zerolog-ban-msgf
+ - id: zerolog-use-stringer
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8e71381e..f2829199 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,237 @@
+## v0.26.3 (2026-02-16)
+
+* Bumped minimum Go version to 1.25.
+* *(client)* Added fields for sending [MSC4354] sticky events.
+* *(bridgev2)* Added automatic message request accepting when sending message.
+* *(mediaproxy)* Added support for federation thumbnail endpoint.
+* *(crypto/ssss)* Improved support for recovery keys with slightly broken
+ metadata.
+* *(crypto)* Changed key import to call session received callback even for
+ sessions that already exist in the database.
+* *(appservice)* Fixed building websocket URL accidentally using file path
+ separators instead of always `/`.
+* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field.
+* *(client)* Fixed incorrect context usage in async uploads.
+* *(crypto)* Fixed panic when passing invalid input to megolm message index
+ parser used for debugging.
+* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned
+ up properly.
+
+[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354
+
+## v0.26.2 (2026-01-16)
+
+* *(bridgev2)* Added chunked portal deletion to avoid database locks when
+ deleting large portals.
+* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata
+ as per [MSC4392].
+* *(bridgev2/login)* Added `default_value` for user input fields.
+* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested
+ HTTP client settings and to reset active connections of the network connector.
+* *(bridgev2)* Added interface to let network connectors get the provisioning
+ API HTTP router and add new endpoints.
+* *(event)* Added blurhash field to Beeper link preview objects.
+* *(event)* Added [MSC4391] support for bot commands.
+* *(event)* Dropped [MSC4332] support for bot commands.
+* *(client)* Changed media download methods to return an error if the provided
+ MXC URI is empty.
+* *(client)* Stabilized support for [MSC4323].
+* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events.
+* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with
+ a portal re-ID call.
+
+[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391
+[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392
+
+## v0.26.1 (2025-12-16)
+
+* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return
+ the mime type from the callback rather than in the return get media return
+ value. The callback can now also redirect the caller to a different file.
+* *(federation)* Added join/knock/leave functions
+ (thanks to [@nexy7574] in [#422]).
+* *(federation/eventauth)* Fixed various incorrect checks.
+* *(client)* Added backoff for retrying media uploads to external URLs
+ (with MSC3870).
+* *(bridgev2/config)* Added support for overriding config fields using
+ environment variables.
+* *(bridgev2/commands)* Added command to mute chat on remote network.
+* *(bridgev2)* Added interface for network connectors to redirect to a different
+ user ID when handling an invite from Matrix.
+* *(bridgev2)* Added interface for signaling message request status of portals.
+* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag
+ is set in chat info.
+* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if
+ bridging the new one is successful.
+* *(bridgev2/mxmain)* Improved error message when trying to run bridge with
+ pre-megabridge database when no database migration exists.
+* *(bridgev2)* Improved reliability of database migration when enabling split
+ portals.
+* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats.
+* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public
+ portal rooms.
+* *(bridgev2)* Stopped hardcoding room versions in favor of checking
+ server capabilities to determine appropriate `/createRoom` parameters.
+
+[#422]: https://github.com/mautrix/go/pull/422
+
+## v0.26.0 (2025-11-16)
+
+* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent`
+ has been able to do the same for a while now.
+* *(client,federation)* Added size limits for responses to make it safer to send
+ requests to untrusted servers.
+* *(client)* Added wrapper for `/admin/whois` client API
+ (thanks to [@nexy7574] in [#411]).
+* *(synapseadmin)* Added `force_purge` option to DeleteRoom
+ (thanks to [@nexy7574] in [#420]).
+* *(statestore)* Added saving join rules for rooms.
+* *(bridgev2)* Added optional automatic rollback of room state if bridging the
+ change to the remote network fails.
+* *(bridgev2)* Added management room notices if transient disconnect state
+ doesn't resolve within 3 minutes.
+* *(bridgev2)* Added interface to signal that certain participants couldn't be
+ invited when creating a group.
+* *(bridgev2)* Added `select` type for user input fields in login.
+* *(bridgev2)* Added interface to let network connector customize personal
+ filtering space.
+* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to
+ other bots.
+* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever
+ possible.
+* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names,
+ and encrypted files.
+* *(bridgev2/commands)* Added command to resync a single portal.
+* *(bridgev2/commands)* Added create group command.
+* *(bridgev2/config)* Added option to limit maximum number of logins.
+* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room
+ is public.
+* *(bridgev2/disappear)* Changed read receipt handling to only start
+ disappearing timers for messages up to the read message (note: may not work in
+ all cases if the read receipt points at an unknown event).
+* *(event/reply)* Changed plaintext reply fallback removal to only happen when
+ an HTML reply fallback is removed successfully.
+* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run.
+* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages.
+* *(federation)* Fixed HTTP method for sending transactions
+ (thanks to [@nexy7574] in [#426]).
+* *(federation)* Fixed response body being closed even when using `DontReadBody`
+ parameter.
+* *(federation)* Fixed validating auth for requests with query params.
+* *(federation/eventauth)* Fixed typo causing restricted joins to not work.
+
+[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169
+[#411]: github.com/mautrix/go/pull/411
+[#420]: github.com/mautrix/go/pull/420
+[#426]: github.com/mautrix/go/pull/426
+
+## v0.25.2 (2025-10-16)
+
+* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into
+ `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old
+ behavior, but most users likely want the relaxed version, as there are real
+ users whose user IDs aren't valid under the strict rules.
+* *(crypto)* Added helper methods for generating and verifying with recovery
+ keys.
+* *(bridgev2/matrix)* Added config option to automatically generate a recovery
+ key for the bridge bot and self-sign the bridge's device.
+* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode
+ for encryption with standard servers like Synapse.
+* *(bridgev2)* Added optional support for implicit read receipts.
+* *(bridgev2)* Added interface for deleting chats on remote network.
+* *(bridgev2)* Added local enforcement of media duration and size limits.
+* *(bridgev2)* Extended event duration logging to log any event taking too long.
+* *(bridgev2)* Improved validation in group creation provisioning API.
+* *(event)* Added event type constant for poll end events.
+* *(client)* Added wrapper for searching user directory.
+* *(client)* Improved support for managing [MSC4140] delayed events.
+* *(crypto/helper)* Changed default sync handling to not block on waiting for
+ decryption keys. On initial sync, keys won't be requested at all by default.
+* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1).
+* *(bridgev2)* Fixed various bugs with migrating to split portals.
+* *(event)* Fixed poll start events having incorrect null `m.relates_to`.
+* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling.
+* *(federation)* Fixed various bugs in event auth.
+
+## v0.25.1 (2025-09-16)
+
+* *(client)* Fixed HTTP method of delete devices API call
+ (thanks to [@fmseals] in [#393]).
+* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints
+ (thanks to [@nexy7574] in [#407]).
+* *(client)* Stabilized support for extensible profiles.
+* *(client)* Stabilized support for `state_after` in sync.
+* *(client)* Removed deprecated MSC2716 requests.
+* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if
+ the content struct doesn't implement `Relatable`.
+* *(crypto)* Changed olm unwedging to ignore newly created sessions if they
+ haven't been used successfully in either direction.
+* *(federation)* Added utilities for generating, parsing, validating and
+ authorizing PDUs.
+ * Note: the new PDU code depends on `GOEXPERIMENT=jsonv2`
+* *(event)* Added `is_animated` flag from [MSC4230] to file info.
+* *(event)* Added types for [MSC4332]: In-room bot commands.
+* *(event)* Added missing poll end event type for [MSC3381].
+* *(appservice)* Fixed URLs not being escaped properly when using unix socket
+ for homeserver connections.
+* *(format)* Added more helpers for forming markdown links.
+* *(event,bridgev2)* Added support for Beeper's disappearing message state event.
+* *(bridgev2)* Redesigned group creation interface and added support in commands
+ and provisioning API.
+* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to
+ get an old event. The method is best effort only, as some configurations don't
+ allow fetching old events.
+* *(bridgev2)* Added shared logic for provisioning that can be reused by the
+ API, commands and other sources.
+* *(bridgev2)* Fixed mentions and URL previews not being copied over when
+ caption and media are merged.
+* *(bridgev2)* Removed config option to change provisioning API prefix, which
+ had already broken in the previous release.
+
+[@fmseals]: https://github.com/fmseals
+[#393]: https://github.com/mautrix/go/pull/393
+[#407]: https://github.com/mautrix/go/pull/407
+[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381
+[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230
+[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332
+
+## v0.25.0 (2025-08-16)
+
+* Bumped minimum Go version to 1.24.
+* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux
+ with standard library ServeMux.
+* *(client,bridgev2)* Added support for creator power in room v12.
+* *(client)* Added option to not set `User-Agent` header for improved Wasm
+ compatibility.
+* *(bridgev2)* Added support for following tombstones.
+* *(bridgev2)* Added interface for getting arbitrary state event from Matrix.
+* *(bridgev2)* Added batching to disappearing message queue to ensure it doesn't
+ use too many resources even if there are a large number of messages.
+* *(bridgev2/commands)* Added support for canceling QR login with `cancel`
+ command.
+* *(client)* Added option to override HTTP client used for .well-known
+ resolution.
+* *(crypto/backup)* Added method for encrypting key backup session without
+ private keys.
+* *(event->id)* Moved room version type and constants to id package.
+* *(bridgev2)* Bots in DM portals will now be added to the functional members
+ state event to hide them from the room name calculation.
+* *(bridgev2)* Changed message delete handling to ignore "delete for me" events
+ if there are multiple Matrix users in the room.
+* *(format/htmlparser)* Changed text processing to collapse multiple spaces into
+ one when outside `pre`/`code` tags.
+* *(format/htmlparser)* Removed link suffix in plaintext output when link text
+ is only missing protocol part of href.
+ * e.g. `example.com` will turn into
+ `example.com` rather than `example.com (https://example.com)`
+* *(appservice)* Switched appservice websockets from gorilla/websocket to
+ coder/websocket.
+* *(bridgev2/matrix)* Fixed encryption key sharing not ignoring ghosts properly.
+* *(crypto/attachments)* Fixed hash check when decrypting file streams.
+* *(crypto)* Removed unnecessary `AlreadyShared` error in `ShareGroupSession`.
+ The function will now act as if it was successful instead.
+
## v0.24.2 (2025-07-16)
* *(bridgev2)* Added support for return values from portal event handlers. Note
@@ -203,6 +437,7 @@
[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156
[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190
[#288]: https://github.com/mautrix/go/pull/288
+[@onestacked]: https://github.com/onestacked
## v0.22.0 (2024-11-16)
diff --git a/README.md b/README.md
index ac41ca78..b1a2edf8 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,9 @@
# mautrix-go
[](https://pkg.go.dev/maunium.net/go/mautrix)
-A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks),
-[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp)
+A Golang Matrix framework. Used by [gomuks](https://gomuks.app),
+[go-neb](https://github.com/matrix-org/go-neb),
+[mautrix-whatsapp](https://github.com/mautrix/whatsapp)
and others.
Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net)
@@ -13,9 +14,10 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or
In addition to the basic client API features the original project has, this framework also has:
* Appservice support (Intent API like mautrix-python, room state storage, etc)
-* End-to-end encryption support (incl. interactive SAS verification)
+* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc)
* High-level module for building puppeting bridges
-* High-level module for building chat clients
+* Partial federation module (making requests, PDU processing and event authorization)
+* A media proxy server which can be used to expose anything as a Matrix media repo
* Wrapper functions for the Synapse admin API
* Structs for parsing event content
* Helpers for parsing and generating Matrix HTML
diff --git a/appservice/appservice.go b/appservice/appservice.go
index 518e1073..d7037ef6 100644
--- a/appservice/appservice.go
+++ b/appservice/appservice.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2023 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -19,8 +19,7 @@ import (
"syscall"
"time"
- "github.com/gorilla/mux"
- "github.com/gorilla/websocket"
+ "github.com/coder/websocket"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
"gopkg.in/yaml.v3"
@@ -43,7 +42,7 @@ func Create() *AppService {
intents: make(map[id.UserID]*IntentAPI),
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
StateStore: mautrix.NewMemoryStateStore().(StateStore),
- Router: mux.NewRouter(),
+ Router: http.NewServeMux(),
UserAgent: mautrix.DefaultUserAgent,
txnIDC: NewTransactionIDCache(128),
Live: true,
@@ -61,12 +60,12 @@ func Create() *AppService {
DefaultHTTPRetries: 4,
}
- as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
- as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
- as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
- as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
- as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
- as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
+ as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
+ as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
+ as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
+ as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
+ as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
+ as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)
return as
}
@@ -114,13 +113,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil)
// QueryHandler handles room alias and user ID queries from the homeserver.
type QueryHandler interface {
- QueryAlias(alias string) bool
+ QueryAlias(alias id.RoomAlias) bool
QueryUser(userID id.UserID) bool
}
type QueryHandlerStub struct{}
-func (qh *QueryHandlerStub) QueryAlias(alias string) bool {
+func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool {
return false
}
@@ -128,7 +127,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool {
return false
}
-type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
+type WebsocketHandler func(WebsocketCommand) (ok bool, data any)
type StateStore interface {
mautrix.StateStore
@@ -160,7 +159,7 @@ type AppService struct {
QueryHandler QueryHandler
StateStore StateStore
- Router *mux.Router
+ Router *http.ServeMux
UserAgent string
server *http.Server
HTTPClient *http.Client
@@ -179,7 +178,6 @@ type AppService struct {
intentsLock sync.RWMutex
ws *websocket.Conn
- wsWriteLock sync.Mutex
StopWebsocket func(error)
websocketHandlers map[string]WebsocketHandler
websocketHandlersLock sync.RWMutex
@@ -336,7 +334,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error {
} else if as.hsURLForClient.Scheme == "" {
as.hsURLForClient.Scheme = "https"
}
- as.hsURLForClient.RawPath = parsedURL.EscapedPath()
+ as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath()
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar}
@@ -362,7 +360,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
AccessToken: as.Registration.AppToken,
UserAgent: as.UserAgent,
StateStore: as.StateStore,
- Log: as.Log.With().Str("as_user_id", userID.String()).Logger(),
+ Log: as.Log.With().Stringer("as_user_id", userID).Logger(),
Client: as.HTTPClient,
DefaultHTTPRetries: as.DefaultHTTPRetries,
SpecVersions: as.SpecVersions,
diff --git a/appservice/http.go b/appservice/http.go
index 1ebe6e56..27ce6288 100644
--- a/appservice/http.go
+++ b/appservice/http.go
@@ -17,7 +17,6 @@ import (
"syscall"
"time"
- "github.com/gorilla/mux"
"github.com/rs/zerolog"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/exstrings"
@@ -95,8 +94,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
return
}
- vars := mux.Vars(r)
- txnID := vars["txnID"]
+ txnID := r.PathValue("txnID")
if len(txnID) == 0 {
mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w)
return
@@ -203,7 +201,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def
}
err := evt.Content.ParseRaw(evt.Type)
if errors.Is(err, event.ErrUnsupportedContentType) {
- log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event")
+ log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event")
} else if err != nil {
log.Warn().Err(err).
Str("event_id", evt.ID.String()).
@@ -240,8 +238,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
return
}
- vars := mux.Vars(r)
- roomAlias := vars["roomAlias"]
+ roomAlias := id.RoomAlias(r.PathValue("roomAlias"))
ok := as.QueryHandler.QueryAlias(roomAlias)
if ok {
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
@@ -256,8 +253,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
return
}
- vars := mux.Vars(r)
- userID := id.UserID(vars["userID"])
+ userID := id.UserID(r.PathValue("userID"))
ok := as.QueryHandler.QueryUser(userID)
if ok {
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
diff --git a/appservice/intent.go b/appservice/intent.go
index d6cda137..5d43f190 100644
--- a/appservice/intent.go
+++ b/appservice/intent.go
@@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
}
func (intent *IntentAPI) Register(ctx context.Context) error {
- _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{
+ _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{
Username: intent.Localpart,
Type: mautrix.AuthTypeAppservice,
InhibitLogin: true,
@@ -86,6 +86,7 @@ func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error {
type EnsureJoinedParams struct {
IgnoreCache bool
BotOverride *mautrix.Client
+ Via []string
}
func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error {
@@ -99,11 +100,17 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
return nil
}
- if err := intent.EnsureRegistered(ctx); err != nil {
+ err := intent.EnsureRegistered(ctx)
+ if err != nil {
return fmt.Errorf("failed to ensure joined: %w", err)
}
- resp, err := intent.JoinRoomByID(ctx, roomID)
+ var resp *mautrix.RespJoinRoom
+ if len(params.Via) > 0 {
+ resp, err = intent.JoinRoom(ctx, roomID.String(), &mautrix.ReqJoinRoom{Via: params.Via})
+ } else {
+ resp, err = intent.JoinRoomByID(ctx, roomID)
+ }
if err != nil {
bot := intent.bot
if params.BotOverride != nil {
@@ -207,23 +214,31 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any {
}
}
-func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
+func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
- return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON)
+ return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...)
}
-func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
+func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
}
- contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
- return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
+ if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
+ return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
+ }
+ contentJSON = intent.AddDoublePuppetValue(contentJSON)
+ return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...)
}
-func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
+// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead
+func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
+ return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
+}
+
+func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
if eventType != event.StateMember || stateKey != string(intent.UserID) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
@@ -232,15 +247,12 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
- return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON)
+ return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...)
}
+// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead
func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
- if err := intent.EnsureJoined(ctx, roomID); err != nil {
- return nil, err
- }
- contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
- return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts)
+ return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
}
func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
@@ -299,7 +311,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...)
- return &mautrix.RespJoinRoom{}, err
+ return &mautrix.RespJoinRoom{RoomID: roomID}, err
}
return intent.Client.JoinRoomByID(ctx, roomID)
}
@@ -368,6 +380,24 @@ func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id
return member
}
+func (intent *IntentAPI) FillPowerLevelCreateEvent(ctx context.Context, roomID id.RoomID, pl *event.PowerLevelsEventContent) error {
+ if pl.CreateEvent != nil {
+ return nil
+ }
+ var err error
+ pl.CreateEvent, err = intent.StateStore.GetCreate(ctx, roomID)
+ if err != nil {
+ return fmt.Errorf("failed to get create event from cache: %w", err)
+ } else if pl.CreateEvent != nil {
+ return nil
+ }
+ pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "")
+ if err != nil {
+ return fmt.Errorf("failed to get create event from server: %w", err)
+ }
+ return nil
+}
+
func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID)
if err != nil {
@@ -377,6 +407,12 @@ func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl
if pl == nil {
pl = &event.PowerLevelsEventContent{}
err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl)
+ if err != nil {
+ return
+ }
+ }
+ if pl.CreateEvent == nil {
+ pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "")
}
return
}
@@ -391,8 +427,7 @@ func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, us
return nil, err
}
- if pl.GetUserLevel(userID) != level {
- pl.SetUserLevel(userID, level)
+ if pl.EnsureUserLevelAs(intent.UserID, userID, level) {
return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl)
}
return nil, nil
@@ -481,7 +516,7 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU
// No need to update
return nil
}
- if !avatarURL.IsEmpty() {
+ if !avatarURL.IsEmpty() && !intent.SpecVersions.Supports(mautrix.BeeperFeatureHungry) {
// Some homeservers require the avatar to be downloaded before setting it
resp, _ := intent.Download(ctx, avatarURL)
if resp != nil {
diff --git a/appservice/websocket.go b/appservice/websocket.go
index 3d5bd232..ef65e65a 100644
--- a/appservice/websocket.go
+++ b/appservice/websocket.go
@@ -11,15 +11,15 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"net/http"
"net/url"
- "path/filepath"
+ "path"
"strings"
"sync"
"sync/atomic"
- "time"
- "github.com/gorilla/websocket"
+ "github.com/coder/websocket"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -28,11 +28,9 @@ import (
)
type WebsocketRequest struct {
- ReqID int `json:"id,omitempty"`
- Command string `json:"command"`
- Data interface{} `json:"data"`
-
- Deadline time.Duration `json:"-"`
+ ReqID int `json:"id,omitempty"`
+ Command string `json:"command"`
+ Data any `json:"data"`
}
type WebsocketCommand struct {
@@ -43,7 +41,7 @@ type WebsocketCommand struct {
Ctx context.Context `json:"-"`
}
-func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest {
+func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" {
return nil
}
@@ -58,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketR
var prefixMessage string
for unwrappedErr != nil {
errorData, jsonErr = json.Marshal(unwrappedErr)
- if errorData != nil && len(errorData) > 2 && jsonErr == nil {
+ if len(errorData) > 2 && jsonErr == nil {
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
prefixMessage = strings.TrimRight(prefixMessage, ": ")
break
@@ -100,8 +98,8 @@ type WebsocketMessage struct {
}
const (
- WebsocketCloseConnReplaced = 4001
- WebsocketCloseTxnNotAcknowledged = 4002
+ WebsocketCloseConnReplaced websocket.StatusCode = 4001
+ WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002
)
type MeowWebsocketCloseCode string
@@ -135,7 +133,7 @@ func (mwcc MeowWebsocketCloseCode) String() string {
}
type CloseCommand struct {
- Code int `json:"-"`
+ Code websocket.StatusCode `json:"-"`
Command string `json:"command"`
Status MeowWebsocketCloseCode `json:"status"`
}
@@ -145,15 +143,15 @@ func (cc CloseCommand) Error() string {
}
func parseCloseError(err error) error {
- closeError := &websocket.CloseError{}
+ var closeError websocket.CloseError
if !errors.As(err, &closeError) {
return err
}
var closeCommand CloseCommand
closeCommand.Code = closeError.Code
closeCommand.Command = "disconnect"
- if len(closeError.Text) > 0 {
- jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
+ if len(closeError.Reason) > 0 {
+ jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand)
if jsonErr != nil {
return err
}
@@ -161,7 +159,7 @@ func parseCloseError(err error) error {
if len(closeCommand.Status) == 0 {
if closeCommand.Code == WebsocketCloseConnReplaced {
closeCommand.Status = MeowConnectionReplaced
- } else if closeCommand.Code == websocket.CloseServiceRestart {
+ } else if closeCommand.Code == websocket.StatusServiceRestart {
closeCommand.Status = MeowServerShuttingDown
}
}
@@ -172,20 +170,23 @@ func (as *AppService) HasWebsocket() bool {
return as.ws != nil
}
-func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error {
+func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error {
ws := as.ws
if cmd == nil {
return nil
} else if ws == nil {
return ErrWebsocketNotConnected
}
- as.wsWriteLock.Lock()
- defer as.wsWriteLock.Unlock()
- if cmd.Deadline == 0 {
- cmd.Deadline = 3 * time.Minute
+ wr, err := ws.Writer(ctx, websocket.MessageText)
+ if err != nil {
+ return err
}
- _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline))
- return ws.WriteJSON(cmd)
+ err = json.NewEncoder(wr).Encode(cmd)
+ if err != nil {
+ _ = wr.Close()
+ return err
+ }
+ return wr.Close()
}
func (as *AppService) clearWebsocketResponseWaiters() {
@@ -222,12 +223,12 @@ func (er *ErrorResponse) Error() string {
return fmt.Sprintf("%s: %s", er.Code, er.Message)
}
-func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error {
+func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error {
cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1))
respChan := make(chan *WebsocketCommand, 1)
as.addWebsocketResponseWaiter(cmd.ReqID, respChan)
defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan)
- err := as.SendWebsocket(cmd)
+ err := as.SendWebsocket(ctx, cmd)
if err != nil {
return err
}
@@ -256,7 +257,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques
}
}
-func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) {
+func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) {
zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command")
return false, fmt.Errorf("unknown request type")
}
@@ -280,14 +281,28 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg
return true, &WebsocketTransactionResponse{TxnID: msg.TxnID}
}
-func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
+func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) {
defer stopFunc(ErrWebsocketUnknownError)
- ctx := context.Background()
for {
- var msg WebsocketMessage
- err := ws.ReadJSON(&msg)
+ msgType, reader, err := ws.Reader(ctx)
if err != nil {
- as.Log.Debug().Err(err).Msg("Error reading from websocket")
+ as.Log.Debug().Err(err).Msg("Error getting reader from websocket")
+ stopFunc(parseCloseError(err))
+ return
+ } else if msgType != websocket.MessageText {
+ as.Log.Debug().Msg("Ignoring non-text message from websocket")
+ continue
+ }
+ data, err := io.ReadAll(reader)
+ if err != nil {
+ as.Log.Debug().Err(err).Msg("Error reading data from websocket")
+ stopFunc(parseCloseError(err))
+ return
+ }
+ var msg WebsocketMessage
+ err = json.Unmarshal(data, &msg)
+ if err != nil {
+ as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket")
stopFunc(parseCloseError(err))
return
}
@@ -298,11 +313,11 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
with = with.Str("transaction_id", msg.TxnID)
}
log := with.Logger()
- ctx = log.WithContext(ctx)
+ ctx := log.WithContext(ctx)
if msg.Command == "" || msg.Command == "transaction" {
ok, resp := as.WebsocketTransactionHandler(ctx, msg)
go func() {
- err := as.SendWebsocket(msg.MakeResponse(ok, resp))
+ err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp))
if err != nil {
log.Warn().Err(err).Msg("Failed to send response to websocket transaction")
} else {
@@ -334,7 +349,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
}
go func() {
okResp, data := handler(msg.WebsocketCommand)
- err := as.SendWebsocket(msg.MakeResponse(okResp, data))
+ err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data))
if err != nil {
log.Error().Err(err).Msg("Failed to send response to websocket command")
} else if okResp {
@@ -347,7 +362,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn)
}
}
-func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
+func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error {
var parsed *url.URL
if baseURL != "" {
var err error
@@ -359,18 +374,21 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
copiedURL := *as.hsURLForClient
parsed = &copiedURL
}
- parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
+ parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
if parsed.Scheme == "http" {
parsed.Scheme = "ws"
} else if parsed.Scheme == "https" {
parsed.Scheme = "wss"
}
- ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
- "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
- "User-Agent": []string{as.BotClient().UserAgent},
+ ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{
+ HTTPClient: as.HTTPClient,
+ HTTPHeader: http.Header{
+ "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
+ "User-Agent": []string{as.BotClient().UserAgent},
- "X-Mautrix-Process-ID": []string{as.ProcessID},
- "X-Mautrix-Websocket-Version": []string{"3"},
+ "X-Mautrix-Process-ID": []string{as.ProcessID},
+ "X-Mautrix-Websocket-Version": []string{"3"},
+ },
})
if resp != nil && resp.StatusCode >= 400 {
var errResp mautrix.RespError
@@ -401,12 +419,13 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
}
})
}
+ ws.SetReadLimit(50 * 1024 * 1024)
as.ws = ws
as.StopWebsocket = stopFunc
as.PrepareWebsocket()
as.Log.Debug().Msg("Appservice transaction websocket opened")
- go as.consumeWebsocket(stopFunc, ws)
+ go as.consumeWebsocket(ctx, stopFunc, ws)
var onConnectDone atomic.Bool
if onConnect != nil {
@@ -428,12 +447,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
as.ws = nil
}
- _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second))
- err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
- if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
- as.Log.Warn().Err(err).Msg("Error writing close message to websocket")
- }
- err = ws.Close()
+ err = ws.Close(websocket.StatusGoingAway, "")
if err != nil {
as.Log.Warn().Err(err).Msg("Error closing websocket")
}
diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go
index a4ce033e..226adc90 100644
--- a/bridgev2/bridge.go
+++ b/bridgev2/bridge.go
@@ -9,11 +9,14 @@ package bridgev2
import (
"context"
"fmt"
+ "os"
"sync"
+ "sync/atomic"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exhttp"
"go.mau.fi/util/exsync"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
@@ -51,6 +54,7 @@ type Bridge struct {
Background bool
ExternallyManagedDB bool
+ stopping atomic.Bool
wakeupBackfillQueue chan struct{}
stopBackfillQueue *exsync.Event
@@ -120,12 +124,13 @@ func (br *Bridge) Start(ctx context.Context) error {
if err != nil {
return err
}
- br.PostStart(ctx)
+ go br.PostStart(ctx)
return nil
}
func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error {
br.Background = true
+ br.stopping.Store(false)
err := br.StartConnectors(ctx)
if err != nil {
return err
@@ -161,6 +166,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
case <-time.After(20 * time.Second):
case <-ctx.Done():
}
+ br.stopping.Store(true)
return nil
} else {
br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode")
@@ -170,6 +176,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
func (br *Bridge) StartConnectors(ctx context.Context) error {
br.Log.Info().Msg("Starting bridge")
+ br.stopping.Store(false)
if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil {
br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background())
br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx)
@@ -182,7 +189,11 @@ func (br *Bridge) StartConnectors(ctx context.Context) error {
}
}
if !br.Background {
- br.didSplitPortals = br.MigrateToSplitPortals(ctx)
+ var postMigrate func()
+ br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx)
+ if postMigrate != nil {
+ defer postMigrate()
+ }
}
br.Log.Info().Msg("Starting Matrix connector")
err := br.Matrix.Start(ctx)
@@ -271,20 +282,64 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps b
Msg("Resent bridge info to all portals")
}
-func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool {
+func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) {
log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger()
ctx = log.WithContext(ctx)
if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" {
- return false
+ return false, nil
}
affected, err := br.DB.Portal.MigrateToSplitPortals(ctx)
if err != nil {
- log.Err(err).Msg("Failed to migrate portals")
- return false
+ log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals")
+ os.Exit(31)
+ return false, nil
}
log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals")
+ affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx)
+ if err != nil {
+ log.Err(err).Msg("Failed to fix parent portals after split portal migration")
+ os.Exit(31)
+ return false, nil
+ }
+ log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration")
+ withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx)
+ if err != nil {
+ log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate")
+ os.Exit(31)
+ return false, nil
+ }
+ var roomsToDelete []id.RoomID
+ log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver")
+ for _, portal := range withoutReceiver {
+ if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil {
+ log.Err(err).
+ Str("portal_id", string(portal.ID)).
+ Stringer("mxid", portal.MXID).
+ Msg("Failed to delete portal database row that failed to migrate")
+ } else if portal.MXID != "" {
+ log.Debug().
+ Str("portal_id", string(portal.ID)).
+ Stringer("mxid", portal.MXID).
+ Msg("Marked portal room for deletion from homeserver")
+ roomsToDelete = append(roomsToDelete, portal.MXID)
+ } else {
+ log.Debug().
+ Str("portal_id", string(portal.ID)).
+ Msg("Deleted portal row with no Matrix room")
+ }
+ }
br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true")
- return affected > 0
+ log.Info().Msg("Finished split portal migration successfully")
+ return affected > 0, func() {
+ for _, roomID := range roomsToDelete {
+ if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil {
+ log.Err(err).
+ Stringer("mxid", roomID).
+ Msg("Failed to delete portal room that failed to migrate")
+ }
+ }
+ log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate")
+ }
}
func (br *Bridge) StartLogins(ctx context.Context) error {
@@ -319,6 +374,46 @@ func (br *Bridge) StartLogins(ctx context.Context) error {
return nil
}
+func (br *Bridge) ResetNetworkConnections() {
+ nrn, ok := br.Network.(NetworkResettingNetwork)
+ if ok {
+ br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections")
+ nrn.ResetNetworkConnections()
+ return
+ }
+
+ br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually")
+ for _, login := range br.GetAllCachedUserLogins() {
+ login.Log.Debug().Msg("Disconnecting and recreating client for network reset")
+ ctx := login.Log.WithContext(br.BackgroundCtx)
+ login.Client.Disconnect()
+ err := login.recreateClient(ctx)
+ if err != nil {
+ login.Log.Err(err).Msg("Failed to recreate client during network reset")
+ login.BridgeState.Send(status.BridgeState{
+ StateEvent: status.StateUnknownError,
+ Error: "bridgev2-network-reset-fail",
+ Info: map[string]any{"go_error": err.Error()},
+ })
+ } else {
+ login.Client.Connect(ctx)
+ }
+ }
+ br.Log.Info().Msg("Finished resetting all user logins")
+}
+
+func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings {
+ mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings)
+ if ok {
+ return mchs.GetHTTPClientSettings()
+ }
+ return exhttp.SensibleClientSettings
+}
+
+func (br *Bridge) IsStopping() bool {
+ return br.stopping.Load()
+}
+
func (br *Bridge) Stop() {
br.stop(false, 0)
}
@@ -329,6 +424,7 @@ func (br *Bridge) StopWithTimeout(timeout time.Duration) {
func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) {
br.Log.Info().Msg("Shutting down bridge")
+ br.stopping.Store(true)
br.DisappearLoop.Stop()
br.stopBackfillQueue.Set()
br.Matrix.PreStop()
diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go
index 53282e41..eedae1e8 100644
--- a/bridgev2/bridgeconfig/backfill.go
+++ b/bridgev2/bridgeconfig/backfill.go
@@ -34,10 +34,12 @@ type BackfillQueueConfig struct {
MaxBatchesOverride map[string]int `yaml:"max_batches_override"`
}
-func (bqc *BackfillQueueConfig) GetOverride(name string) int {
- override, ok := bqc.MaxBatchesOverride[name]
- if !ok {
- return bqc.MaxBatches
+func (bqc *BackfillQueueConfig) GetOverride(names ...string) int {
+ for _, name := range names {
+ override, ok := bqc.MaxBatchesOverride[name]
+ if ok {
+ return override
+ }
}
- return override
+ return bqc.MaxBatches
}
diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go
index 9bdee5fe..bd6b9c06 100644
--- a/bridgev2/bridgeconfig/config.go
+++ b/bridgev2/bridgeconfig/config.go
@@ -33,6 +33,8 @@ type Config struct {
Encryption EncryptionConfig `yaml:"encryption"`
Logging zeroconfig.Config `yaml:"logging"`
+ EnvConfigPrefix string `yaml:"env_config_prefix"`
+
ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"`
}
@@ -60,36 +62,40 @@ type CleanupOnLogouts struct {
}
type BridgeConfig struct {
- CommandPrefix string `yaml:"command_prefix"`
- PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
- PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
- AsyncEvents bool `yaml:"async_events"`
- SplitPortals bool `yaml:"split_portals"`
- ResendBridgeInfo bool `yaml:"resend_bridge_info"`
- NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
- BridgeStatusNotices string `yaml:"bridge_status_notices"`
- UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
- BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
- BridgeNotices bool `yaml:"bridge_notices"`
- TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
- OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
- MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
- DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
- CrossRoomReplies bool `yaml:"cross_room_replies"`
- OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
- CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
- Relay RelayConfig `yaml:"relay"`
- Permissions PermissionConfig `yaml:"permissions"`
- Backfill BackfillConfig `yaml:"backfill"`
+ CommandPrefix string `yaml:"command_prefix"`
+ PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
+ PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
+ AsyncEvents bool `yaml:"async_events"`
+ SplitPortals bool `yaml:"split_portals"`
+ ResendBridgeInfo bool `yaml:"resend_bridge_info"`
+ NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
+ BridgeStatusNotices string `yaml:"bridge_status_notices"`
+ UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
+ UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"`
+ BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
+ BridgeNotices bool `yaml:"bridge_notices"`
+ TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
+ OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
+ MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
+ DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
+ CrossRoomReplies bool `yaml:"cross_room_replies"`
+ OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
+ RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"`
+ KickMatrixUsers bool `yaml:"kick_matrix_users"`
+ CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
+ Relay RelayConfig `yaml:"relay"`
+ Permissions PermissionConfig `yaml:"permissions"`
+ Backfill BackfillConfig `yaml:"backfill"`
}
type MatrixConfig struct {
- MessageStatusEvents bool `yaml:"message_status_events"`
- DeliveryReceipts bool `yaml:"delivery_receipts"`
- MessageErrorNotices bool `yaml:"message_error_notices"`
- SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
- FederateRooms bool `yaml:"federate_rooms"`
- UploadFileThreshold int64 `yaml:"upload_file_threshold"`
+ MessageStatusEvents bool `yaml:"message_status_events"`
+ DeliveryReceipts bool `yaml:"delivery_receipts"`
+ MessageErrorNotices bool `yaml:"message_error_notices"`
+ SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
+ FederateRooms bool `yaml:"federate_rooms"`
+ UploadFileThreshold int64 `yaml:"upload_file_threshold"`
+ GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"`
}
type AnalyticsConfig struct {
@@ -99,7 +105,6 @@ type AnalyticsConfig struct {
}
type ProvisioningConfig struct {
- Prefix string `yaml:"prefix"`
SharedSecret string `yaml:"shared_secret"`
DebugEndpoints bool `yaml:"debug_endpoints"`
EnableSessionTransfers bool `yaml:"enable_session_transfers"`
@@ -112,10 +117,12 @@ type DirectMediaConfig struct {
}
type PublicMediaConfig struct {
- Enabled bool `yaml:"enabled"`
- SigningKey string `yaml:"signing_key"`
- HashLength int `yaml:"hash_length"`
- Expiry int `yaml:"expiry"`
+ Enabled bool `yaml:"enabled"`
+ SigningKey string `yaml:"signing_key"`
+ Expiry int `yaml:"expiry"`
+ HashLength int `yaml:"hash_length"`
+ PathPrefix string `yaml:"path_prefix"`
+ UseDatabase bool `yaml:"use_database"`
}
type DoublePuppetConfig struct {
diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go
index 1ef7e18f..934613ca 100644
--- a/bridgev2/bridgeconfig/encryption.go
+++ b/bridgev2/bridgeconfig/encryption.go
@@ -16,6 +16,8 @@ type EncryptionConfig struct {
Require bool `yaml:"require"`
Appservice bool `yaml:"appservice"`
MSC4190 bool `yaml:"msc4190"`
+ MSC4392 bool `yaml:"msc4392"`
+ SelfSign bool `yaml:"self_sign"`
PlaintextMentions bool `yaml:"plaintext_mentions"`
diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go
index fb2a86d6..954a37c3 100644
--- a/bridgev2/bridgeconfig/legacymigrate.go
+++ b/bridgev2/bridgeconfig/legacymigrate.go
@@ -133,9 +133,7 @@ func doMigrateLegacy(helper up.Helper, python bool) {
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"})
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"})
- CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
- CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"})
diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go
index 610051e0..9efe068e 100644
--- a/bridgev2/bridgeconfig/permissions.go
+++ b/bridgev2/bridgeconfig/permissions.go
@@ -24,6 +24,7 @@ type Permissions struct {
DoublePuppet bool `yaml:"double_puppet"`
Admin bool `yaml:"admin"`
ManageRelay bool `yaml:"manage_relay"`
+ MaxLogins int `yaml:"max_logins"`
}
type PermissionConfig map[string]*Permissions
@@ -40,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool {
_, hasExampleDomain := pc["example.com"]
_, hasExampleUser := pc["@admin:example.com"]
exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain)
- if len(pc) <= exampleLen {
- return false
- }
- return true
+ return len(pc) > exampleLen
}
func (pc PermissionConfig) Get(userID id.UserID) Permissions {
diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go
index b69a1fdb..92515ea0 100644
--- a/bridgev2/bridgeconfig/upgrade.go
+++ b/bridgev2/bridgeconfig/upgrade.go
@@ -33,6 +33,7 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key")
helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices")
helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect")
+ helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects")
helper.Copy(up.Bool, "bridge", "bridge_matrix_leave")
helper.Copy(up.Bool, "bridge", "bridge_notices")
helper.Copy(up.Bool, "bridge", "tag_only_on_create")
@@ -40,6 +41,8 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "bridge", "mute_only_on_create")
helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages")
helper.Copy(up.Bool, "bridge", "cross_room_replies")
+ helper.Copy(up.Bool, "bridge", "revert_failed_state_changes")
+ helper.Copy(up.Bool, "bridge", "kick_matrix_users")
helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled")
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private")
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed")
@@ -98,12 +101,12 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "matrix", "sync_direct_chat_list")
helper.Copy(up.Bool, "matrix", "federate_rooms")
helper.Copy(up.Int, "matrix", "upload_file_threshold")
+ helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info")
helper.Copy(up.Str|up.Null, "analytics", "token")
helper.Copy(up.Str|up.Null, "analytics", "url")
helper.Copy(up.Str|up.Null, "analytics", "user_id")
- helper.Copy(up.Str, "provisioning", "prefix")
if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" {
sharedSecret := random.String(64)
helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret")
@@ -133,6 +136,8 @@ func doUpgrade(helper up.Helper) {
}
helper.Copy(up.Int, "public_media", "expiry")
helper.Copy(up.Int, "public_media", "hash_length")
+ helper.Copy(up.Str|up.Null, "public_media", "path_prefix")
+ helper.Copy(up.Bool, "public_media", "use_database")
helper.Copy(up.Bool, "backfill", "enabled")
helper.Copy(up.Int, "backfill", "max_initial_messages")
@@ -158,6 +163,8 @@ func doUpgrade(helper up.Helper) {
} else {
helper.Copy(up.Bool, "encryption", "msc4190")
}
+ helper.Copy(up.Bool, "encryption", "msc4392")
+ helper.Copy(up.Bool, "encryption", "self_sign")
helper.Copy(up.Bool, "encryption", "allow_key_sharing")
if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" {
helper.Set(up.Str, random.String(64), "encryption", "pickle_key")
@@ -180,6 +187,8 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Int, "encryption", "rotation", "messages")
helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation")
+ helper.Copy(up.Str|up.Null, "env_config_prefix")
+
helper.Copy(up.Map, "logging")
}
@@ -207,6 +216,7 @@ var SpacedBlocks = [][]string{
{"backfill"},
{"double_puppet"},
{"encryption"},
+ {"env_config_prefix"},
{"logging"},
}
diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go
index f31d4e92..96d9fd5c 100644
--- a/bridgev2/bridgestate.go
+++ b/bridgev2/bridgestate.go
@@ -15,12 +15,15 @@ import (
"time"
"github.com/rs/zerolog"
+ "go.mau.fi/util/exfmt"
"maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
)
+var CatchBridgeStateQueuePanics = true
+
type BridgeStateQueue struct {
prevUnsent *status.BridgeState
prevSent *status.BridgeState
@@ -29,8 +32,13 @@ type BridgeStateQueue struct {
bridge *Bridge
login *UserLogin
+ firstTransientDisconnect time.Time
+ cancelScheduledNotice atomic.Pointer[context.CancelFunc]
+
stopChan chan struct{}
stopReconnect atomic.Pointer[context.CancelFunc]
+
+ unknownErrorReconnects int
}
func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) {
@@ -74,31 +82,63 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() {
if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil {
(*cancelFn)()
}
+ if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil {
+ (*cancelFn)()
+ }
}
func (bsq *BridgeStateQueue) loop() {
- defer func() {
- err := recover()
- if err != nil {
- bsq.login.Log.Error().
- Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
- Any(zerolog.ErrorFieldName, err).
- Msg("Panic in bridge state loop")
- }
- }()
+ if CatchBridgeStateQueuePanics {
+ defer func() {
+ err := recover()
+ if err != nil {
+ bsq.login.Log.Error().
+ Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
+ Any(zerolog.ErrorFieldName, err).
+ Msg("Panic in bridge state loop")
+ }
+ }()
+ }
for state := range bsq.ch {
bsq.immediateSendBridgeState(state)
}
}
-func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) {
+func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) {
+ log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger()
+ ctx := log.WithContext(bsq.bridge.BackgroundCtx)
+ if !bsq.waitForTransientDisconnectReconnect(ctx) {
+ return
+ }
+ prevUnsent := bsq.GetPrevUnsent()
+ prev := bsq.GetPrev()
+ if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent ||
+ prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect {
+ log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice")
+ return
+ }
+ log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice")
+ bsq.sendNotice(ctx, triggeredBy, true)
+}
+
+func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) {
noticeConfig := bsq.bridge.Config.BridgeStatusNotices
isError := state.StateEvent == status.StateBadCredentials ||
state.StateEvent == status.StateUnknownError ||
- state.UserAction == status.UserActionOpenNative
+ state.UserAction == status.UserActionOpenNative ||
+ (isDelayed && state.StateEvent == status.StateTransientDisconnect)
sendNotice := noticeConfig == "all" || (noticeConfig == "errors" &&
(isError || (bsq.errorSent && state.StateEvent == status.StateConnected)))
+ if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError {
+ bsq.firstTransientDisconnect = time.Time{}
+ }
if !sendNotice {
+ if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect {
+ if bsq.firstTransientDisconnect.IsZero() {
+ bsq.firstTransientDisconnect = time.Now()
+ }
+ go bsq.scheduleNotice(state)
+ }
return
}
managementRoom, err := bsq.login.User.GetManagementRoom(ctx)
@@ -114,6 +154,9 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge
if state.Error != "" {
message += fmt.Sprintf(" (`%s`)", state.Error)
}
+ if isDelayed {
+ message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay))
+ }
if state.Message != "" {
message += fmt.Sprintf(": %s", state.Message)
}
@@ -151,8 +194,14 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat
} else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError {
log.Debug().Msg("Not reconnecting as the previous state was not an unknown error")
return
+ } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects {
+ log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached")
+ return
}
- log.Info().Msg("Disconnecting and reconnecting login due to unknown error")
+ bsq.unknownErrorReconnects++
+ log.Info().
+ Int("reconnect_num", bsq.unknownErrorReconnects).
+ Msg("Disconnecting and reconnecting login due to unknown error")
bsq.login.Disconnect()
log.Debug().Msg("Disconnection finished, recreating client and reconnecting")
err := bsq.login.recreateClient(ctx)
@@ -171,14 +220,30 @@ func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) b
return false
}
reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2))
+ return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect)
+}
+
+const TransientDisconnectNoticeDelay = 3 * time.Minute
+
+func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool {
+ timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay))
+ zerolog.Ctx(ctx).Trace().
+ Stringer("duration", timeUntilSchedule).
+ Msg("Waiting before sending notice about transient disconnect")
+ return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice)
+}
+
+func (bsq *BridgeStateQueue) waitForReconnect(
+ ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc],
+) bool {
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
- if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil {
+ if oldCancel := ptr.Swap(&cancel); oldCancel != nil {
(*oldCancel)()
}
select {
case <-time.After(reconnectIn):
- return bsq.stopReconnect.CompareAndSwap(&cancel, nil)
+ return ptr.CompareAndSwap(&cancel, nil)
case <-cancelCtx.Done():
return false
case <-bsq.stopChan:
@@ -198,7 +263,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState)
}
ctx := bsq.login.Log.WithContext(context.Background())
- bsq.sendNotice(ctx, state)
+ bsq.sendNotice(ctx, state, false)
retryIn := 2
for {
diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go
index 4c93dbd4..1cae98fe 100644
--- a/bridgev2/commands/debug.go
+++ b/bridgev2/commands/debug.go
@@ -7,10 +7,13 @@
package commands
import (
+ "encoding/json"
"strings"
+ "time"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/event"
)
var CommandRegisterPush = &FullHandler{
@@ -59,3 +62,64 @@ var CommandRegisterPush = &FullHandler{
RequiresLogin: true,
NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI],
}
+
+var CommandSendAccountData = &FullHandler{
+ Func: func(ce *Event) {
+ if len(ce.Args) < 2 {
+ ce.Reply("Usage: `$cmdprefix debug-account-data ")
+ return
+ }
+ var content event.Content
+ evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType}
+ ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0]))
+ err := json.Unmarshal([]byte(ce.RawArgs), &content)
+ if err != nil {
+ ce.Reply("Failed to parse JSON: %v", err)
+ return
+ }
+ err = content.ParseRaw(evtType)
+ if err != nil {
+ ce.Reply("Failed to deserialize content: %v", err)
+ return
+ }
+ res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{
+ Sender: ce.User.MXID,
+ Type: evtType,
+ Timestamp: time.Now().UnixMilli(),
+ RoomID: ce.RoomID,
+ Content: content,
+ })
+ ce.Reply("Result: %+v", res)
+ },
+ Name: "debug-account-data",
+ Help: HelpMeta{
+ Section: HelpSectionAdmin,
+ Description: "Send a room account data event to the bridge",
+ Args: "<_type_> <_content_>",
+ },
+ RequiresAdmin: true,
+ RequiresPortal: true,
+ RequiresLogin: true,
+}
+
+var CommandResetNetwork = &FullHandler{
+ Func: func(ce *Event) {
+ if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") {
+ nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork)
+ if ok {
+ nrn.ResetHTTPTransport()
+ } else {
+ ce.Reply("Network connector does not support resetting HTTP transport")
+ }
+ }
+ ce.Bridge.ResetNetworkConnections()
+ ce.React("✅️")
+ },
+ Name: "debug-reset-network",
+ Help: HelpMeta{
+ Section: HelpSectionAdmin,
+ Description: "Reset network connections to the remote network",
+ Args: "[--reset-transport]",
+ },
+ RequiresAdmin: true,
+}
diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go
index 3544998c..96d62d3e 100644
--- a/bridgev2/commands/login.go
+++ b/bridgev2/commands/login.go
@@ -70,6 +70,15 @@ func fnLogin(ce *Event) {
}
ce.Args = ce.Args[1:]
}
+ if reauth == nil && ce.User.HasTooManyLogins() {
+ ce.Reply(
+ "You have reached the maximum number of logins (%d). "+
+ "Please logout from an existing login before creating a new one. "+
+ "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.",
+ ce.User.Permissions.MaxLogins,
+ )
+ return
+ }
flows := ce.Bridge.Network.GetLoginFlows()
var chosenFlowID string
if len(ce.Args) > 0 {
@@ -112,6 +121,7 @@ func fnLogin(ce *Event) {
ce.Reply("Failed to start login: %v", err)
return
}
+ ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process")
nextStep = checkLoginCommandDirectParams(ce, login, nextStep)
if nextStep != nil {
@@ -190,11 +200,14 @@ type userInputLoginCommandState struct {
func (uilcs *userInputLoginCommandState) promptNext(ce *Event) {
field := uilcs.RemainingFields[0]
+ parts := []string{fmt.Sprintf("Please enter your %s", field.Name)}
if field.Description != "" {
- ce.Reply("Please enter your %s\n%s", field.Name, field.Description)
- } else {
- ce.Reply("Please enter your %s", field.Name)
+ parts = append(parts, field.Description)
}
+ if len(field.Options) > 0 {
+ parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `")))
+ }
+ ce.Reply(strings.Join(parts, "\n"))
StoreCommandState(ce.User, &CommandState{
Next: MinimalCommandHandlerFunc(uilcs.submitNext),
Action: "Login",
@@ -239,14 +252,19 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return fmt.Errorf("failed to upload image: %w", err)
}
content := &event.MessageEventContent{
- MsgType: event.MsgImage,
- FileName: "qr.png",
- URL: qrMXC,
- File: qrFile,
-
+ MsgType: event.MsgImage,
+ FileName: "qr.png",
+ URL: qrMXC,
+ File: qrFile,
Body: qr,
Format: event.FormatHTML,
FormattedBody: fmt.Sprintf("%s
", html.EscapeString(qr)),
+ Info: &event.FileInfo{
+ MimeType: "image/png",
+ Width: qrSizePx,
+ Height: qrSizePx,
+ Size: len(qrData),
+ },
}
if *prevEventID != "" {
content.SetEdit(*prevEventID)
@@ -261,6 +279,36 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return nil
}
+func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
+ for _, att := range atts {
+ if att.FileName == "" {
+ return fmt.Errorf("missing attachment filename")
+ }
+ mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
+ if err != nil {
+ return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
+ }
+ content := &event.MessageEventContent{
+ MsgType: att.Type,
+ FileName: att.FileName,
+ URL: mxc,
+ File: file,
+ Info: &event.FileInfo{
+ MimeType: att.Info.MimeType,
+ Width: att.Info.Width,
+ Height: att.Info.Height,
+ Size: att.Info.Size,
+ },
+ Body: att.FileName,
+ }
+ _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
+ if err != nil {
+ return nil
+ }
+ }
+ return nil
+}
+
type contextKey int
const (
@@ -273,6 +321,13 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
prevEvent = new(id.EventID)
ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent)
}
+ cancelCtx, cancelFunc := context.WithCancel(ce.Ctx)
+ defer cancelFunc()
+ StoreCommandState(ce.User, &CommandState{
+ Action: "Login",
+ Cancel: cancelFunc,
+ })
+ defer StoreCommandState(ce.User, nil)
switch step.DisplayAndWaitParams.Type {
case bridgev2.LoginDisplayTypeQR:
err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent)
@@ -292,7 +347,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
login.Cancel()
return
}
- nextStep, err := login.Wait(ce.Ctx)
+ nextStep, err := login.Wait(cancelCtx)
// Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited)
if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) {
_, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{
@@ -445,6 +500,7 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
}
func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
+ ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
if step.Instructions != "" {
ce.Reply(step.Instructions)
}
@@ -459,6 +515,10 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte
Override: override,
}).prompt(ce)
case bridgev2.LoginStepTypeUserInput:
+ err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
+ if err != nil {
+ ce.Reply("Failed to send attachments: %v", err)
+ }
(&userInputLoginCommandState{
Login: login.(bridgev2.LoginProcessUserInput),
RemainingFields: step.UserInputParams.Fields,
diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go
index c28e3a32..391c3685 100644
--- a/bridgev2/commands/processor.go
+++ b/bridgev2/commands/processor.go
@@ -41,10 +41,11 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor {
}
proc.AddHandlers(
CommandHelp, CommandCancel,
- CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
+ CommandRegisterPush, CommandSendAccountData, CommandResetNetwork,
+ CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
CommandSetRelay, CommandUnsetRelay,
- CommandResolveIdentifier, CommandStartChat, CommandSearch,
+ CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute,
CommandSudo, CommandDoIn,
)
return proc
diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go
index af756c87..94c19739 100644
--- a/bridgev2/commands/relay.go
+++ b/bridgev2/commands/relay.go
@@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) {
}
onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly
var relay *bridgev2.UserLogin
- if len(ce.Args) == 0 {
+ if len(ce.Args) == 0 && ce.Portal.Receiver == "" {
relay = ce.User.GetDefaultLogin()
isLoggedIn := relay != nil
if onlySetDefaultRelays {
@@ -73,9 +73,19 @@ func fnSetRelay(ce *Event) {
}
}
} else {
- relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+ var targetID networkid.UserLoginID
+ if ce.Portal.Receiver != "" {
+ targetID = ce.Portal.Receiver
+ if len(ce.Args) > 0 && ce.Args[0] != string(targetID) {
+ ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID)
+ return
+ }
+ } else {
+ targetID = networkid.UserLoginID(ce.Args[0])
+ }
+ relay = ce.Bridge.GetCachedUserLoginByID(targetID)
if relay == nil {
- ce.Reply("User login with ID `%s` not found", ce.Args[0])
+ ce.Reply("User login with ID `%s` not found", targetID)
return
} else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) {
// All good
diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go
index 719d3dd5..c7b05a6e 100644
--- a/bridgev2/commands/startchat.go
+++ b/bridgev2/commands/startchat.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,13 +8,21 @@ package commands
import (
"context"
+ "errors"
"fmt"
"html"
+ "maps"
+ "slices"
"strings"
"time"
+ "github.com/rs/zerolog"
+
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/bridgev2/provisionutil"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
@@ -30,6 +38,35 @@ var CommandResolveIdentifier = &FullHandler{
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
+var CommandSyncChat = &FullHandler{
+ Func: func(ce *Event) {
+ login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to find login for sync")
+ ce.Reply("Failed to find login: %v", err)
+ return
+ } else if login == nil {
+ ce.Reply("No login found for sync")
+ return
+ }
+ info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get chat info for sync")
+ ce.Reply("Failed to get chat info: %v", err)
+ return
+ }
+ ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{})
+ ce.React("✅️")
+ },
+ Name: "sync-portal",
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Sync the current portal room",
+ },
+ RequiresPortal: true,
+ RequiresLogin: true,
+}
+
var CommandStartChat = &FullHandler{
Func: fnResolveIdentifier,
Name: "start-chat",
@@ -43,9 +80,15 @@ var CommandStartChat = &FullHandler{
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
-func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
- remainingArgs := ce.Args[1:]
- login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
+ var remainingArgs []string
+ if len(ce.Args) > 1 {
+ remainingArgs = ce.Args[1:]
+ }
+ var login *bridgev2.UserLogin
+ if len(ce.Args) > 0 {
+ login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+ }
if login == nil || login.UserMXID != ce.User.MXID {
remainingArgs = ce.Args
login = ce.User.GetDefaultLogin()
@@ -57,24 +100,13 @@ func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Even
return login, api, remainingArgs
}
-func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string {
- var targetName string
- var targetMXID id.UserID
- if resp.Ghost != nil {
- if resp.UserInfo != nil {
- resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
- }
- targetName = resp.Ghost.Name
- targetMXID = resp.Ghost.Intent.GetMXID()
- } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
- targetName = *resp.UserInfo.Name
- }
- if targetMXID != "" {
- return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL())
- } else if targetName != "" {
- return fmt.Sprintf("`%s` / %s", resp.UserID, targetName)
+func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string {
+ if resp.MXID != "" {
+ return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL())
+ } else if resp.Name != "" {
+ return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name)
} else {
- return fmt.Sprintf("`%s`", resp.UserID)
+ return fmt.Sprintf("`%s`", resp.ID)
}
}
@@ -87,65 +119,137 @@ func fnResolveIdentifier(ce *Event) {
if api == nil {
return
}
+ allLogins := ce.User.GetUserLogins()
createChat := ce.Command == "start-chat" || ce.Command == "pm"
identifier := strings.Join(identifierParts, " ")
- resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat)
+ resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat)
+ for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ {
+ resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat)
+ }
if err != nil {
- ce.Log.Err(err).Msg("Failed to resolve identifier")
ce.Reply("Failed to resolve identifier: %v", err)
return
} else if resp == nil {
ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true)
return
}
- formattedName := formatResolveIdentifierResult(ce.Ctx, resp)
+ formattedName := formatResolveIdentifierResult(resp)
if createChat {
- if resp.Chat == nil {
- ce.Reply("Interface error: network connector did not return chat for create chat request")
- return
+ name := resp.Portal.Name
+ if name == "" {
+ name = resp.Portal.MXID.String()
}
- portal := resp.Chat.Portal
- if portal == nil {
- portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get portal")
- ce.Reply("Failed to get portal: %v", err)
- return
- }
- }
- if resp.Chat.PortalInfo == nil {
- resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get portal info")
- ce.Reply("Failed to get portal info: %v", err)
- return
- }
- }
- if portal.MXID != "" {
- name := portal.Name
- if name == "" {
- name = portal.MXID.String()
- }
- portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{})
- ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
+ if !resp.JustCreated {
+ ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
} else {
- err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to create room")
- ce.Reply("Failed to create room: %v", err)
- return
- }
- name := portal.Name
- if name == "" {
- name = portal.MXID.String()
- }
- ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
+ ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
}
} else {
ce.Reply("Found %s", formattedName)
}
}
+var CommandCreateGroup = &FullHandler{
+ Func: fnCreateGroup,
+ Name: "create-group",
+ Aliases: []string{"create"},
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Create a new group chat for the current Matrix room",
+ Args: "[_group type_]",
+ },
+ RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI],
+}
+
+func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) {
+ evt, err := provider.GetStateEvent(ctx, roomID, evtType, "")
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation")
+ } else if evt != nil {
+ content, _ = evt.Content.Parsed.(T)
+ }
+ return
+}
+
+func fnCreateGroup(ce *Event) {
+ ce.Bridge.Matrix.GetCapabilities()
+ login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group")
+ if api == nil {
+ return
+ }
+ stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ ce.Reply("Matrix connector doesn't support fetching room state")
+ return
+ }
+ members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get room members for group creation")
+ ce.Reply("Failed to get room members: %v", err)
+ return
+ }
+ caps := ce.Bridge.Network.GetCapabilities()
+ params := &bridgev2.GroupCreateParams{
+ Username: "",
+ Participants: make([]networkid.UserID, 0, len(members)-2),
+ Parent: nil, // TODO check space parent event
+ Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider),
+ Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider),
+ Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider),
+ Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider),
+ RoomID: ce.RoomID,
+ }
+ for userID, member := range members {
+ if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() {
+ continue
+ }
+ if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok {
+ params.Participants = append(params.Participants, parsedUserID)
+ } else if !ce.Bridge.Config.SplitPortals {
+ if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil {
+ ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member")
+ } else if user != nil {
+ // TODO add user logins to participants
+ //for _, login := range user.GetUserLogins() {
+ // params.Participants = append(params.Participants, login.GetUserID())
+ //}
+ }
+ }
+ }
+
+ if len(caps.Provisioning.GroupCreation) == 0 {
+ ce.Reply("No group creation types defined in network capabilities")
+ return
+ } else if len(remainingArgs) > 0 {
+ params.Type = remainingArgs[0]
+ } else if len(caps.Provisioning.GroupCreation) == 1 {
+ for params.Type = range caps.Provisioning.GroupCreation {
+ // The loop assigns the variable we want
+ }
+ } else {
+ types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `")
+ ce.Reply("Please specify type of group to create: `%s`", types)
+ return
+ }
+ resp, err := provisionutil.CreateGroup(ce.Ctx, login, params)
+ if err != nil {
+ ce.Reply("Failed to create group: %v", err)
+ return
+ }
+ var postfix string
+ if len(resp.FailedParticipants) > 0 {
+ failedParticipantsStrings := make([]string, len(resp.FailedParticipants))
+ i := 0
+ for participantID, meta := range resp.FailedParticipants {
+ failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason)
+ i++
+ }
+ postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n")
+ }
+ ce.Reply("Successfully created group `%s`%s", resp.ID, postfix)
+}
+
var CommandSearch = &FullHandler{
Func: fnSearch,
Name: "search",
@@ -163,35 +267,67 @@ func fnSearch(ce *Event) {
ce.Reply("Usage: `$cmdprefix search `")
return
}
- _, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
+ login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
if api == nil {
return
}
- results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " "))
+ resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " "))
if err != nil {
- ce.Log.Err(err).Msg("Failed to search for users")
ce.Reply("Failed to search for users: %v", err)
return
}
- resultsString := make([]string, len(results))
- for i, res := range results {
- formattedName := formatResolveIdentifierResult(ce.Ctx, res)
+ resultsString := make([]string, len(resp.Results))
+ for i, res := range resp.Results {
+ formattedName := formatResolveIdentifierResult(res)
resultsString[i] = fmt.Sprintf("* %s", formattedName)
- if res.Chat != nil {
- if res.Chat.Portal == nil {
- res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey)
- if err != nil {
- ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal")
- }
- }
- if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" {
- portalName := res.Chat.Portal.Name
- if portalName == "" {
- portalName = res.Chat.Portal.MXID.String()
- }
- resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL())
+ if res.Portal != nil && res.Portal.MXID != "" {
+ portalName := res.Portal.Name
+ if portalName == "" {
+ portalName = res.Portal.MXID.String()
}
+ resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL())
}
}
ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n"))
}
+
+var CommandMute = &FullHandler{
+ Func: fnMute,
+ Name: "mute",
+ Aliases: []string{"unmute"},
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Mute or unmute a chat on the remote network",
+ Args: "[duration]",
+ },
+ RequiresPortal: true,
+ RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI],
+}
+
+func fnMute(ce *Event) {
+ _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats")
+ var mutedUntil int64
+ if ce.Command == "mute" {
+ mutedUntil = -1
+ if len(ce.Args) > 0 {
+ duration, err := time.ParseDuration(ce.Args[0])
+ if err != nil {
+ ce.Reply("Invalid duration: %v", err)
+ return
+ }
+ mutedUntil = time.Now().Add(duration).UnixMilli()
+ }
+ }
+ err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{
+ MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{
+ Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil},
+ Portal: ce.Portal,
+ },
+ })
+ if err != nil {
+ ce.Reply("Failed to %s chat: %v", ce.Command, err)
+ } else {
+ ce.React("✅️")
+ }
+}
diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go
index 224ae626..1f920640 100644
--- a/bridgev2/database/backfillqueue.go
+++ b/bridgev2/database/backfillqueue.go
@@ -78,6 +78,11 @@ const (
dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11
WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3
`
+ markBackfillTaskNotDoneQuery = `
+ UPDATE backfill_task
+ SET is_done = false
+ WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND user_login_id = $4
+ `
getNextBackfillQuery = `
SELECT
bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done,
@@ -127,6 +132,10 @@ func (btq *BackfillTaskQuery) Update(ctx context.Context, bq *BackfillTask) erro
return btq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...)
}
+func (btq *BackfillTaskQuery) MarkNotDone(ctx context.Context, portalKey networkid.PortalKey, userLoginID networkid.UserLoginID) error {
+ return btq.Exec(ctx, markBackfillTaskNotDoneQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver, userLoginID)
+}
+
func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error) {
return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano())
}
diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go
index f1789441..05abddf0 100644
--- a/bridgev2/database/database.go
+++ b/bridgev2/database/database.go
@@ -7,13 +7,7 @@
package database
import (
- "encoding/json"
- "reflect"
- "strings"
-
"go.mau.fi/util/dbutil"
- "golang.org/x/exp/constraints"
- "golang.org/x/exp/maps"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -34,6 +28,7 @@ type Database struct {
UserPortal *UserPortalQuery
BackfillTask *BackfillTaskQuery
KV *KVQuery
+ PublicMedia *PublicMediaQuery
}
type MetaMerger interface {
@@ -141,6 +136,12 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa
BridgeID: bridgeID,
Database: db,
},
+ PublicMedia: &PublicMediaQuery{
+ BridgeID: bridgeID,
+ QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia {
+ return &PublicMedia{}
+ }),
+ },
}
}
@@ -151,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID)
panic("bridge ID mismatch")
}
}
-
-func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) {
- if val, found := m[key]; found {
- floatVal, ok := val.(float64)
- if ok {
- return T(floatVal), true
- }
- tVal, ok := val.(T)
- if ok {
- return tVal, true
- }
- }
- return 0, false
-}
-
-func unmarshalMerge(input []byte, data any, extra *map[string]any) error {
- err := json.Unmarshal(input, data)
- if err != nil {
- return err
- }
- err = json.Unmarshal(input, extra)
- if err != nil {
- return err
- }
- if *extra == nil {
- *extra = make(map[string]any)
- }
- return nil
-}
-
-func marshalMerge(data any, extra map[string]any) ([]byte, error) {
- if extra == nil {
- return json.Marshal(data)
- }
- merged := make(map[string]any)
- maps.Copy(merged, extra)
- dataRef := reflect.ValueOf(data).Elem()
- dataType := dataRef.Type()
- for _, field := range reflect.VisibleFields(dataType) {
- parts := strings.Split(field.Tag.Get("json"), ",")
- if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" {
- continue
- }
- fieldVal := dataRef.FieldByIndex(field.Index)
- if fieldVal.IsZero() {
- delete(merged, parts[0])
- } else {
- merged[parts[0]] = fieldVal.Interface()
- }
- }
- return json.Marshal(merged)
-}
diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go
index 23db1448..df36b205 100644
--- a/bridgev2/database/disappear.go
+++ b/bridgev2/database/disappear.go
@@ -12,56 +12,94 @@ import (
"time"
"go.mau.fi/util/dbutil"
+ "go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
-// DisappearingType represents the type of a disappearing message timer.
-type DisappearingType string
+// Deprecated: use [event.DisappearingType]
+type DisappearingType = event.DisappearingType
+// Deprecated: use constants in event package
const (
- DisappearingTypeNone DisappearingType = ""
- DisappearingTypeAfterRead DisappearingType = "after_read"
- DisappearingTypeAfterSend DisappearingType = "after_send"
+ DisappearingTypeNone = event.DisappearingTypeNone
+ DisappearingTypeAfterRead = event.DisappearingTypeAfterRead
+ DisappearingTypeAfterSend = event.DisappearingTypeAfterSend
)
// DisappearingSetting represents a disappearing message timer setting
// by combining a type with a timer and an optional start timestamp.
type DisappearingSetting struct {
- Type DisappearingType
+ Type event.DisappearingType
Timer time.Duration
DisappearAt time.Time
}
+func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting {
+ if evt == nil || evt.Type == event.DisappearingTypeNone {
+ return DisappearingSetting{}
+ }
+ return DisappearingSetting{
+ Type: evt.Type,
+ Timer: evt.Timer.Duration,
+ }
+}
+
+func (ds DisappearingSetting) Normalize() DisappearingSetting {
+ if ds.Type == event.DisappearingTypeNone {
+ ds.Timer = 0
+ } else if ds.Timer == 0 {
+ ds.Type = event.DisappearingTypeNone
+ }
+ return ds
+}
+
+func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting {
+ ds.DisappearAt = start.Add(ds.Timer)
+ return ds
+}
+
+func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer {
+ if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 {
+ return &event.BeeperDisappearingTimer{}
+ }
+ return &event.BeeperDisappearingTimer{
+ Type: ds.Type,
+ Timer: jsontime.MS(ds.Timer),
+ }
+}
+
type DisappearingMessageQuery struct {
BridgeID networkid.BridgeID
*dbutil.QueryHelper[*DisappearingMessage]
}
type DisappearingMessage struct {
- BridgeID networkid.BridgeID
- RoomID id.RoomID
- EventID id.EventID
+ BridgeID networkid.BridgeID
+ RoomID id.RoomID
+ EventID id.EventID
+ Timestamp time.Time
DisappearingSetting
}
const (
upsertDisappearingMessageQuery = `
- INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at)
- VALUES ($1, $2, $3, $4, $5, $6)
+ INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at
`
startDisappearingMessagesQuery = `
UPDATE disappearing_message
SET disappear_at=$1 + timer
- WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read'
- RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at
+ WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4
+ RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
`
getUpcomingDisappearingMessagesQuery = `
- SELECT bridge_id, mx_room, mxid, type, timer, disappear_at
+ SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2
- ORDER BY disappear_at
+ ORDER BY disappear_at LIMIT $3
`
deleteDisappearingMessageQuery = `
DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2
@@ -73,12 +111,12 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe
return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...)
}
-func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) {
- return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID)
+func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) {
+ return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano())
}
-func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) {
- return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano())
+func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) {
+ return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano(), limit)
}
func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error {
@@ -86,17 +124,19 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even
}
func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
+ var timestamp int64
var disappearAt sql.NullInt64
- err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt)
+ err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt)
if err != nil {
return nil, err
}
if disappearAt.Valid {
d.DisappearAt = time.Unix(0, disappearAt.Int64)
}
+ d.Timestamp = time.Unix(0, timestamp)
return d, nil
}
func (d *DisappearingMessage) sqlVariables() []any {
- return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
+ return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
}
diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go
index c32929ad..16af35ca 100644
--- a/bridgev2/database/ghost.go
+++ b/bridgev2/database/ghost.go
@@ -7,12 +7,17 @@
package database
import (
+ "bytes"
"context"
"encoding/hex"
+ "encoding/json"
+ "fmt"
"go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/id"
)
@@ -22,6 +27,55 @@ type GhostQuery struct {
*dbutil.QueryHelper[*Ghost]
}
+type ExtraProfile map[string]json.RawMessage
+
+func (ep *ExtraProfile) Set(key string, value any) error {
+ if key == "displayname" || key == "avatar_url" {
+ return fmt.Errorf("cannot set reserved profile key %q", key)
+ }
+ marshaled, err := json.Marshal(value)
+ if err != nil {
+ return err
+ }
+ if *ep == nil {
+ *ep = make(ExtraProfile)
+ }
+ (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled)
+ return nil
+}
+
+func (ep *ExtraProfile) With(key string, value any) *ExtraProfile {
+ exerrors.PanicIfNotNil(ep.Set(key, value))
+ return ep
+}
+
+func canonicalizeIfObject(data json.RawMessage) json.RawMessage {
+ if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
+ return canonicaljson.CanonicalJSONAssumeValid(data)
+ }
+ return data
+}
+
+func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) {
+ if len(*ep) == 0 {
+ return
+ }
+ if *dest == nil {
+ *dest = make(ExtraProfile)
+ }
+ for key, val := range *ep {
+ if key == "displayname" || key == "avatar_url" {
+ continue
+ }
+ existing, exists := (*dest)[key]
+ if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) {
+ (*dest)[key] = val
+ changed = true
+ }
+ }
+ return
+}
+
type Ghost struct {
BridgeID networkid.BridgeID
ID networkid.UserID
@@ -35,13 +89,14 @@ type Ghost struct {
ContactInfoSet bool
IsBot bool
Identifiers []string
+ ExtraProfile ExtraProfile
Metadata any
}
const (
getGhostBaseQuery = `
SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
+ name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
FROM ghost
`
getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2`
@@ -49,13 +104,14 @@ const (
insertGhostQuery = `
INSERT INTO ghost (
bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
+ name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
`
updateGhostQuery = `
UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6,
- name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12
+ name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10,
+ identifiers=$11, extra_profile=$12, metadata=$13
WHERE bridge_id=$1 AND id=$2
`
)
@@ -86,7 +142,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) {
&g.BridgeID, &g.ID,
&g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC,
&g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot,
- dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
+ dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
)
if err != nil {
return nil, err
@@ -116,6 +172,6 @@ func (g *Ghost) sqlVariables() []any {
g.BridgeID, g.ID,
g.Name, g.AvatarID, avatarHash, g.AvatarMXC,
g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot,
- dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
+ dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
}
}
diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go
index 5a1af019..bca26ed5 100644
--- a/bridgev2/database/kvstore.go
+++ b/bridgev2/database/kvstore.go
@@ -20,8 +20,10 @@ import (
type Key string
const (
- KeySplitPortalsEnabled Key = "split_portals_enabled"
- KeyBridgeInfoVersion Key = "bridge_info_version"
+ KeySplitPortalsEnabled Key = "split_portals_enabled"
+ KeyBridgeInfoVersion Key = "bridge_info_version"
+ KeyEncryptionStateResynced Key = "encryption_state_resynced"
+ KeyRecoveryKey Key = "recovery_key"
)
type KVQuery struct {
diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go
index 9b3b1493..4fd599a8 100644
--- a/bridgev2/database/message.go
+++ b/bridgev2/database/message.go
@@ -11,9 +11,12 @@ import (
"crypto/sha256"
"database/sql"
"encoding/base64"
+ "fmt"
"strings"
+ "sync"
"time"
+ "github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -24,6 +27,7 @@ type MessageQuery struct {
BridgeID networkid.BridgeID
MetaType MetaTypeCreator
*dbutil.QueryHelper[*Message]
+ chunkDeleteLock sync.Mutex
}
type Message struct {
@@ -64,8 +68,8 @@ const (
getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`
getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5`
getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1`
- getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1`
- getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1`
+ getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1`
+ getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1`
getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4`
getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1`
@@ -96,6 +100,10 @@ const (
deleteMessagePartByRowIDQuery = `
DELETE FROM message WHERE bridge_id=$1 AND rowid=$2
`
+ deleteMessageChunkQuery = `
+ DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5
+ `
+ getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1`
)
func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) {
@@ -180,6 +188,85 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error {
return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID)
}
+func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) {
+ res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID)
+ if err != nil {
+ return 0, err
+ }
+ return res.RowsAffected()
+}
+
+func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) {
+ err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID)
+ return
+}
+
+const deleteChunkSize = 100_000
+
+func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error {
+ if mq.GetDB().Dialect != dbutil.SQLite {
+ return nil
+ }
+ log := zerolog.Ctx(ctx).With().
+ Str("action", "delete messages in chunks").
+ Stringer("portal_key", portal).
+ Logger()
+ if !mq.chunkDeleteLock.TryLock() {
+ log.Warn().Msg("Portal deletion lock is being held, waiting...")
+ mq.chunkDeleteLock.Lock()
+ log.Debug().Msg("Acquired portal deletion lock after waiting")
+ }
+ defer mq.chunkDeleteLock.Unlock()
+ total, err := mq.CountMessagesInPortal(ctx, portal)
+ if err != nil {
+ return fmt.Errorf("failed to count messages in portal: %w", err)
+ } else if total < deleteChunkSize/3 {
+ return nil
+ }
+ globalMaxRowID, err := mq.getMaxRowID(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get max row ID: %w", err)
+ }
+ log.Debug().
+ Int("total_count", total).
+ Int64("global_max_row_id", globalMaxRowID).
+ Msg("Portal has lots of messages, deleting in chunks to avoid database locks")
+ maxRowID := int64(deleteChunkSize)
+ globalMaxRowID += deleteChunkSize * 1.2
+ var dbTimeUsed time.Duration
+ globalStart := time.Now()
+ for total > 500 && maxRowID < globalMaxRowID {
+ start := time.Now()
+ count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID)
+ duration := time.Since(start)
+ dbTimeUsed += duration
+ if err != nil {
+ return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err)
+ }
+ total -= int(count)
+ maxRowID += deleteChunkSize
+ sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond))
+ log.Debug().
+ Int64("max_row_id", maxRowID).
+ Int64("deleted_count", count).
+ Int("remaining_count", total).
+ Dur("duration", duration).
+ Dur("sleep_time", sleepTime).
+ Msg("Deleted chunk of messages")
+ select {
+ case <-time.After(sleepTime):
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+ log.Debug().
+ Int("remaining_count", total).
+ Dur("db_time_used", dbTimeUsed).
+ Dur("total_duration", time.Since(globalStart)).
+ Msg("Finished chunked delete of messages in portal")
+ return nil
+}
+
func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) {
err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count)
return
diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go
index 17e44b09..0e6be286 100644
--- a/bridgev2/database/portal.go
+++ b/bridgev2/database/portal.go
@@ -16,6 +16,7 @@ import (
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -34,9 +35,20 @@ type PortalQuery struct {
*dbutil.QueryHelper[*Portal]
}
+type CapStateFlags uint32
+
+func (csf CapStateFlags) Has(flag CapStateFlags) bool {
+ return csf&flag != 0
+}
+
+const (
+ CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota
+)
+
type CapabilityState struct {
Source networkid.UserLoginID `json:"source"`
ID string `json:"id"`
+ Flags CapStateFlags `json:"flags"`
}
type Portal struct {
@@ -44,30 +56,31 @@ type Portal struct {
networkid.PortalKey
MXID id.RoomID
- ParentKey networkid.PortalKey
- RelayLoginID networkid.UserLoginID
- OtherUserID networkid.UserID
- Name string
- Topic string
- AvatarID networkid.AvatarID
- AvatarHash [32]byte
- AvatarMXC id.ContentURIString
- NameSet bool
- TopicSet bool
- AvatarSet bool
- NameIsCustom bool
- InSpace bool
- RoomType RoomType
- Disappear DisappearingSetting
- CapState CapabilityState
- Metadata any
+ ParentKey networkid.PortalKey
+ RelayLoginID networkid.UserLoginID
+ OtherUserID networkid.UserID
+ Name string
+ Topic string
+ AvatarID networkid.AvatarID
+ AvatarHash [32]byte
+ AvatarMXC id.ContentURIString
+ NameSet bool
+ TopicSet bool
+ AvatarSet bool
+ NameIsCustom bool
+ InSpace bool
+ MessageRequest bool
+ RoomType RoomType
+ Disappear DisappearingSetting
+ CapState CapabilityState
+ Metadata any
}
const (
getPortalBaseQuery = `
SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
- name_set, topic_set, avatar_set, name_is_custom, in_space,
+ name_set, topic_set, avatar_set, name_is_custom, in_space, message_request,
room_type, disappear_type, disappear_timer, cap_state,
metadata
FROM portal
@@ -76,7 +89,9 @@ const (
getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')`
getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL`
+ getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC`
getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2`
+ getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3`
getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1`
getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3`
@@ -87,11 +102,11 @@ const (
bridge_id, id, receiver, mxid,
parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, topic_set, name_is_custom, in_space,
+ name_set, avatar_set, topic_set, name_is_custom, in_space, message_request,
room_type, disappear_type, disappear_timer, cap_state,
metadata, relay_bridge_id
) VALUES (
- $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23,
+ $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24,
CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END
)
`
@@ -100,8 +115,8 @@ const (
SET mxid=$4, parent_id=$5, parent_receiver=$6,
relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END,
other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13,
- name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18,
- room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23
+ name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19,
+ room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24
WHERE bridge_id=$1 AND id=$2 AND receiver=$3
`
deletePortalQuery = `
@@ -111,15 +126,33 @@ const (
reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3`
migrateToSplitPortalsQuery = `
UPDATE portal
- SET receiver=COALESCE((
- SELECT login_id
- FROM user_portal
- WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
- LIMIT 1
- ), (
- SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
- ), '')
- WHERE receiver='' AND bridge_id=$1
+ SET receiver=new_receiver
+ FROM (
+ SELECT bridge_id, id, COALESCE((
+ SELECT login_id
+ FROM user_portal
+ WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
+ LIMIT 1
+ ), (
+ SELECT login_id
+ FROM user_portal
+ WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id
+ LIMIT 1
+ ), (
+ SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
+ ), '') AS new_receiver
+ FROM portal
+ WHERE receiver='' AND bridge_id=$1
+ ) updates
+ WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS (
+ SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver
+ )
+ `
+ fixParentsAfterSplitPortalMigrationQuery = `
+ UPDATE portal
+ SET parent_receiver=receiver
+ WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''
+ AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver);
`
)
@@ -147,6 +180,10 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID)
}
+func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) {
+ return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID)
+}
+
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID)
}
@@ -155,6 +192,10 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid.
return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID)
}
+func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) {
+ return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID)
+}
+
func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) {
return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver)
}
@@ -185,6 +226,14 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error)
return res.RowsAffected()
}
+func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) {
+ res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID)
+ if err != nil {
+ return 0, err
+ }
+ return res.RowsAffected()
+}
+
func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString
var disappearTimer sql.NullInt64
@@ -193,7 +242,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
&p.BridgeID, &p.ID, &p.Receiver, &mxid,
&parentID, &parentReceiver, &relayLoginID, &otherUserID,
&p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC,
- &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace,
+ &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest,
&p.RoomType, &disappearType, &disappearTimer,
dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata},
)
@@ -208,7 +257,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
}
if disappearType.Valid {
p.Disappear = DisappearingSetting{
- Type: DisappearingType(disappearType.String),
+ Type: event.DisappearingType(disappearType.String),
Timer: time.Duration(disappearTimer.Int64),
}
}
@@ -240,7 +289,7 @@ func (p *Portal) sqlVariables() []any {
p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID),
dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID),
p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC,
- p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace,
+ p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest,
p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer),
dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata},
}
diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go
new file mode 100644
index 00000000..b667399c
--- /dev/null
+++ b/bridgev2/database/publicmedia.go
@@ -0,0 +1,72 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package database
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "go.mau.fi/util/dbutil"
+
+ "maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/crypto/attachment"
+ "maunium.net/go/mautrix/id"
+)
+
+type PublicMediaQuery struct {
+ BridgeID networkid.BridgeID
+ *dbutil.QueryHelper[*PublicMedia]
+}
+
+type PublicMedia struct {
+ BridgeID networkid.BridgeID
+ PublicID string
+ MXC id.ContentURI
+ Keys *attachment.EncryptedFile
+ MimeType string
+ Expiry time.Time
+}
+
+const (
+ upsertPublicMediaQuery = `
+ INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry)
+ VALUES ($1, $2, $3, $4, $5, $6)
+ ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry
+ `
+ getPublicMediaQuery = `
+ SELECT bridge_id, public_id, mxc, keys, mimetype, expiry
+ FROM public_media WHERE bridge_id=$1 AND public_id=$2
+ `
+)
+
+func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error {
+ ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID)
+ return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...)
+}
+
+func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) {
+ return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID)
+}
+
+func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) {
+ var expiry sql.NullInt64
+ var mimetype sql.NullString
+ err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry)
+ if err != nil {
+ return nil, err
+ }
+ if expiry.Valid {
+ pm.Expiry = time.Unix(0, expiry.Int64)
+ }
+ pm.MimeType = mimetype.String
+ return pm, nil
+}
+
+func (pm *PublicMedia) sqlVariables() []any {
+ return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)}
+}
diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql
index 4eea05bb..6092dc24 100644
--- a/bridgev2/database/upgrades/00-latest.sql
+++ b/bridgev2/database/upgrades/00-latest.sql
@@ -1,4 +1,4 @@
--- v0 -> v22 (compatible with v9+): Latest revision
+-- v0 -> v27 (compatible with v9+): Latest revision
CREATE TABLE "user" (
bridge_id TEXT NOT NULL,
mxid TEXT NOT NULL,
@@ -48,6 +48,7 @@ CREATE TABLE portal (
topic_set BOOLEAN NOT NULL,
name_is_custom BOOLEAN NOT NULL DEFAULT false,
in_space BOOLEAN NOT NULL,
+ message_request BOOLEAN NOT NULL DEFAULT false,
room_type TEXT NOT NULL,
disappear_type TEXT,
disappear_timer BIGINT,
@@ -64,6 +65,7 @@ CREATE TABLE portal (
ON DELETE SET NULL ON UPDATE CASCADE
);
CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid);
+CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
CREATE TABLE ghost (
bridge_id TEXT NOT NULL,
@@ -78,6 +80,7 @@ CREATE TABLE ghost (
contact_info_set BOOLEAN NOT NULL,
is_bot BOOLEAN NOT NULL,
identifiers jsonb NOT NULL,
+ extra_profile jsonb,
metadata jsonb NOT NULL,
PRIMARY KEY (bridge_id, id)
@@ -127,6 +130,7 @@ CREATE TABLE disappearing_message (
bridge_id TEXT NOT NULL,
mx_room TEXT NOT NULL,
mxid TEXT NOT NULL,
+ timestamp BIGINT NOT NULL DEFAULT 0,
type TEXT NOT NULL,
timer BIGINT NOT NULL,
disappear_at BIGINT,
@@ -137,6 +141,7 @@ CREATE TABLE disappearing_message (
REFERENCES portal (bridge_id, mxid)
ON DELETE CASCADE
);
+CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
CREATE TABLE reaction (
bridge_id TEXT NOT NULL,
@@ -215,3 +220,14 @@ CREATE TABLE kv_store (
PRIMARY KEY (bridge_id, key)
);
+
+CREATE TABLE public_media (
+ bridge_id TEXT NOT NULL,
+ public_id TEXT NOT NULL,
+ mxc TEXT NOT NULL,
+ keys jsonb,
+ mimetype TEXT,
+ expiry BIGINT,
+
+ PRIMARY KEY (bridge_id, public_id)
+);
diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql
new file mode 100644
index 00000000..ecd00b8d
--- /dev/null
+++ b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql
@@ -0,0 +1,2 @@
+-- v23 (compatible with v9+): Add event timestamp for disappearing messages
+ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0;
diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql
new file mode 100644
index 00000000..c4290090
--- /dev/null
+++ b/bridgev2/database/upgrades/24-public-media.sql
@@ -0,0 +1,11 @@
+-- v24 (compatible with v9+): Custom URLs for public media
+CREATE TABLE public_media (
+ bridge_id TEXT NOT NULL,
+ public_id TEXT NOT NULL,
+ mxc TEXT NOT NULL,
+ keys jsonb,
+ mimetype TEXT,
+ expiry BIGINT,
+
+ PRIMARY KEY (bridge_id, public_id)
+);
diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql
new file mode 100644
index 00000000..b9d82a7a
--- /dev/null
+++ b/bridgev2/database/upgrades/25-message-requests.sql
@@ -0,0 +1,2 @@
+-- v25 (compatible with v9+): Flag for message request portals
+ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false;
diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql
new file mode 100644
index 00000000..ae5d8cad
--- /dev/null
+++ b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql
@@ -0,0 +1,3 @@
+-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents
+CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
+CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql
new file mode 100644
index 00000000..e8e0549a
--- /dev/null
+++ b/bridgev2/database/upgrades/27-ghost-extra-profile.sql
@@ -0,0 +1,2 @@
+-- v27 (compatible with v9+): Add column for extra ghost profile metadata
+ALTER TABLE ghost ADD COLUMN extra_profile jsonb;
diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go
index 9fa6569a..00ff01c9 100644
--- a/bridgev2/database/userlogin.go
+++ b/bridgev2/database/userlogin.go
@@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin {
func (u *UserLogin) sqlVariables() []any {
var remoteProfile dbutil.JSON
- if !u.RemoteProfile.IsEmpty() {
+ if !u.RemoteProfile.IsZero() {
remoteProfile.Data = &u.RemoteProfile
}
return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}}
diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go
index 278b236b..e928a4c7 100644
--- a/bridgev2/database/userportal.go
+++ b/bridgev2/database/userportal.go
@@ -67,6 +67,9 @@ const (
markLoginAsPreferredQuery = `
UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5
`
+ markAllNotInSpaceQuery = `
+ UPDATE user_portal SET in_space=false WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3
+ `
deleteUserPortalQuery = `
DELETE FROM user_portal WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5
`
@@ -110,6 +113,10 @@ func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogi
return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver)
}
+func (upq *UserPortalQuery) MarkAllNotInSpace(ctx context.Context, portal networkid.PortalKey) error {
+ return upq.Exec(ctx, markAllNotInSpaceQuery, upq.BridgeID, portal.ID, portal.Receiver)
+}
+
func (upq *UserPortalQuery) Delete(ctx context.Context, up *UserPortal) error {
return upq.Exec(ctx, deleteUserPortalQuery, up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver)
}
diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go
index 1d063088..b5c37e8f 100644
--- a/bridgev2/disappear.go
+++ b/bridgev2/disappear.go
@@ -21,7 +21,7 @@ import (
type DisappearLoop struct {
br *Bridge
- NextCheck time.Time
+ nextCheck atomic.Pointer[time.Time]
stop atomic.Pointer[context.CancelFunc]
}
@@ -35,15 +35,30 @@ func (dl *DisappearLoop) Start() {
}
log.Debug().Msg("Disappearing message loop starting")
for {
- dl.NextCheck = time.Now().Add(DisappearCheckInterval)
- messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval)
+ nextCheck := time.Now().Add(DisappearCheckInterval)
+ dl.nextCheck.Store(&nextCheck)
+ const MessageLimit = 200
+ messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit)
if err != nil {
log.Err(err).Msg("Failed to get upcoming disappearing messages")
} else if len(messages) > 0 {
+ if len(messages) >= MessageLimit {
+ lastDisappearTime := messages[len(messages)-1].DisappearAt
+ log.Debug().
+ Int("message_count", len(messages)).
+ Time("last_due", lastDisappearTime).
+ Msg("Deleting disappearing messages synchronously and checking again immediately")
+ // Store the expected next check time to avoid Add spawning unnecessary goroutines.
+ // This can be in the past, in which case Add will put everything in the db, which is also fine.
+ dl.nextCheck.Store(&lastDisappearTime)
+ // If there are many messages, process them synchronously and then check again.
+ dl.sleepAndDisappear(ctx, messages...)
+ continue
+ }
go dl.sleepAndDisappear(ctx, messages...)
}
select {
- case <-time.After(time.Until(dl.NextCheck)):
+ case <-time.After(time.Until(dl.GetNextCheck())):
case <-ctx.Done():
log.Debug().Msg("Disappearing message loop stopping")
return
@@ -51,6 +66,17 @@ func (dl *DisappearLoop) Start() {
}
}
+func (dl *DisappearLoop) GetNextCheck() time.Time {
+ if dl == nil {
+ return time.Time{}
+ }
+ nextCheck := dl.nextCheck.Load()
+ if nextCheck == nil {
+ return time.Time{}
+ }
+ return *nextCheck
+}
+
func (dl *DisappearLoop) Stop() {
if dl == nil {
return
@@ -60,14 +86,14 @@ func (dl *DisappearLoop) Stop() {
}
}
-func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) {
- startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID)
+func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) {
+ startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages")
return
}
startedMessages = slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool {
- return dm.DisappearAt.After(dl.NextCheck)
+ return dm.DisappearAt.After(dl.GetNextCheck())
})
slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int {
return a.DisappearAt.Compare(b.DisappearAt)
@@ -84,17 +110,24 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa
Stringer("event_id", dm.EventID).
Msg("Failed to save disappearing message")
}
- if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) {
+ if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.GetNextCheck()) {
go dl.sleepAndDisappear(zerolog.Ctx(ctx).WithContext(dl.br.BackgroundCtx), dm)
}
}
func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) {
for _, msg := range dms {
- select {
- case <-time.After(time.Until(msg.DisappearAt)):
- case <-ctx.Done():
- return
+ timeUntilDisappear := time.Until(msg.DisappearAt)
+ if timeUntilDisappear <= 0 {
+ if ctx.Err() != nil {
+ return
+ }
+ } else {
+ select {
+ case <-time.After(timeUntilDisappear):
+ case <-ctx.Done():
+ return
+ }
}
resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{
Parsed: &event.RedactionEventContent{
diff --git a/bridgev2/errors.go b/bridgev2/errors.go
index c023dcdf..f6677d2e 100644
--- a/bridgev2/errors.go
+++ b/bridgev2/errors.go
@@ -38,35 +38,53 @@ var ErrNotLoggedIn = errors.New("not logged in")
// but direct media is not enabled.
var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled")
+var ErrPortalIsDeleted = errors.New("portal is deleted")
+var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event")
+
// Common message status errors
var (
- ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
- ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage()
- ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage()
- ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage()
- ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage()
- ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage()
- ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage()
- ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage()
- ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage()
- ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
- ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage()
- ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
- ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
- ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
- ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
- ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
- ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
- ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
- ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
- ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
- ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
- ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
- ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
- ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
+ ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
+ ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
+ ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
+ ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
+ ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
+ ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
+ ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
+ ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
+ ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
+ ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
+ ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
+
+ ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
+ ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
+ ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
+
+ ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true)
)
// Common login interface errors
diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go
index f06c0363..590dd1dc 100644
--- a/bridgev2/ghost.go
+++ b/bridgev2/ghost.go
@@ -9,12 +9,15 @@ package bridgev2
import (
"context"
"crypto/sha256"
+ "encoding/json"
"fmt"
+ "maps"
"net/http"
+ "slices"
"github.com/rs/zerolog"
+ "go.mau.fi/util/exerrors"
"go.mau.fi/util/exmime"
- "golang.org/x/exp/slices"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -134,10 +137,11 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32
}
type UserInfo struct {
- Identifiers []string
- Name *string
- Avatar *Avatar
- IsBot *bool
+ Identifiers []string
+ Name *string
+ Avatar *Avatar
+ IsBot *bool
+ ExtraProfile database.ExtraProfile
ExtraUpdates ExtraUpdater[*Ghost]
}
@@ -158,7 +162,7 @@ func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool {
}
func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
- if ghost.AvatarID == avatar.ID && ghost.AvatarSet {
+ if ghost.AvatarID == avatar.ID && (avatar.Remove || ghost.AvatarMXC != "") && ghost.AvatarSet {
return false
}
ghost.AvatarID = avatar.ID
@@ -168,7 +172,7 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
ghost.AvatarSet = false
zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar")
return true
- } else if newHash == ghost.AvatarHash && ghost.AvatarSet {
+ } else if newHash == ghost.AvatarHash && ghost.AvatarMXC != "" && ghost.AvatarSet {
return true
}
ghost.AvatarHash = newHash
@@ -185,9 +189,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
return true
}
-func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra {
+func (ghost *Ghost) getExtraProfileMeta() any {
bridgeName := ghost.Bridge.Network.GetName()
- return &event.BeeperProfileExtra{
+ baseExtra := &event.BeeperProfileExtra{
RemoteID: string(ghost.ID),
Identifiers: ghost.Identifiers,
Service: bridgeName.BeeperBridgeType,
@@ -195,23 +199,35 @@ func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra {
IsBridgeBot: false,
IsNetworkBot: ghost.IsBot,
}
+ if len(ghost.ExtraProfile) == 0 {
+ return baseExtra
+ }
+ mergedExtra := maps.Clone(ghost.ExtraProfile)
+ baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra))
+ exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra))
+ return mergedExtra
}
-func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool {
- if identifiers != nil {
- slices.Sort(identifiers)
- }
- if ghost.ContactInfoSet &&
- (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) &&
- (isBot == nil || *isBot == ghost.IsBot) {
+func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool {
+ if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta {
+ ghost.ContactInfoSet = false
return false
}
if identifiers != nil {
+ slices.Sort(identifiers)
+ }
+ changed := extraProfile.CopyTo(&ghost.ExtraProfile)
+ if identifiers != nil {
+ changed = changed || !slices.Equal(identifiers, ghost.Identifiers)
ghost.Identifiers = identifiers
}
if isBot != nil {
+ changed = changed || *isBot != ghost.IsBot
ghost.IsBot = *isBot
}
+ if ghost.ContactInfoSet && !changed {
+ return false
+ }
err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta())
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata")
@@ -234,7 +250,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool {
}
func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) {
- if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
+ if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
return
}
info, err := source.Client.GetUserInfo(ctx, ghost)
@@ -244,12 +260,16 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin
zerolog.Ctx(ctx).Debug().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
+ Bool("has_avatar", ghost.AvatarMXC != "").
+ Bool("avatar_set", ghost.AvatarSet).
Msg("Updating ghost info in IfNecessary call")
ghost.UpdateInfo(ctx, info)
} else {
zerolog.Ctx(ctx).Trace().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
+ Bool("has_avatar", ghost.AvatarMXC != "").
+ Bool("avatar_set", ghost.AvatarSet).
Msg("No ghost info received in IfNecessary call")
}
}
@@ -277,9 +297,14 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) {
}
if info.Avatar != nil {
update = ghost.UpdateAvatar(ctx, info.Avatar) || update
+ } else if oldAvatar == "" && !ghost.AvatarSet {
+ // Special case: nil avatar means we're not expecting one ever, if we don't currently have
+ // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary.
+ ghost.AvatarSet = true
+ update = true
}
- if info.Identifiers != nil || info.IsBot != nil {
- update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update
+ if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil {
+ update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update
}
if info.ExtraUpdates != nil {
update = info.ExtraUpdates(ctx, ghost) || update
diff --git a/bridgev2/login.go b/bridgev2/login.go
index 1fa3afbc..b8321719 100644
--- a/bridgev2/login.go
+++ b/bridgev2/login.go
@@ -13,6 +13,7 @@ import (
"strings"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/event"
)
// LoginProcess represents a single occurrence of a user logging into the remote network.
@@ -178,6 +179,8 @@ const (
LoginInputFieldTypeToken LoginInputFieldType = "token"
LoginInputFieldTypeURL LoginInputFieldType = "url"
LoginInputFieldTypeDomain LoginInputFieldType = "domain"
+ LoginInputFieldTypeSelect LoginInputFieldType = "select"
+ LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code"
)
type LoginInputDataField struct {
@@ -189,8 +192,13 @@ type LoginInputDataField struct {
Name string `json:"name"`
// The description of the field shown to the user.
Description string `json:"description"`
+ // A default value that the client can pre-fill the field with.
+ DefaultValue string `json:"default_value,omitempty"`
// A regex pattern that the client can use to validate input client-side.
Pattern string `json:"pattern,omitempty"`
+ // For fields of type select, the valid options.
+ // Pattern may also be filled with a regex that matches the same options.
+ Options []string `json:"options,omitempty"`
// A function that validates the input and optionally cleans it up before it's submitted to the connector.
Validate func(string) (string, error) `json:"-"`
}
@@ -265,6 +273,23 @@ func (f *LoginInputDataField) FillDefaultValidate() {
type LoginUserInputParams struct {
// The fields that the user needs to fill in.
Fields []LoginInputDataField `json:"fields"`
+
+ // Attachments to display alongside the input fields.
+ Attachments []*LoginUserInputAttachment `json:"attachments"`
+}
+
+type LoginUserInputAttachment struct {
+ Type event.MessageType `json:"type,omitempty"`
+ FileName string `json:"filename,omitempty"`
+ Content []byte `json:"content,omitempty"`
+ Info LoginUserInputAttachmentInfo `json:"info,omitempty"`
+}
+
+type LoginUserInputAttachmentInfo struct {
+ MimeType string `json:"mimetype,omitempty"`
+ Width int `json:"w,omitempty"`
+ Height int `json:"h,omitempty"`
+ Size int `json:"size,omitempty"`
}
type LoginCompleteParams struct {
diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go
index 7af2d128..5a2df953 100644
--- a/bridgev2/matrix/connector.go
+++ b/bridgev2/matrix/connector.go
@@ -12,20 +12,21 @@ import (
"encoding/base64"
"errors"
"fmt"
+ "net/http"
"net/url"
"os"
"regexp"
"strings"
"sync"
"time"
- "unsafe"
- "github.com/gorilla/mux"
_ "github.com/lib/pq"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
_ "go.mau.fi/util/dbutil/litestream"
+ "go.mau.fi/util/exbytes"
"go.mau.fi/util/exsync"
+ "go.mau.fi/util/ptr"
"go.mau.fi/util/random"
"golang.org/x/sync/semaphore"
@@ -80,6 +81,8 @@ type Connector struct {
MediaConfig mautrix.RespMediaConfig
SpecVersions *mautrix.RespVersions
+ SpecCaps *mautrix.RespCapabilities
+ specCapsLock sync.Mutex
Capabilities *bridgev2.MatrixCapabilities
IgnoreUnsupportedServer bool
@@ -101,6 +104,7 @@ type Connector struct {
var (
_ bridgev2.MatrixConnector = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithServer = (*Connector)(nil)
+ _ bridgev2.MatrixConnectorWithArbitraryRoomState = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil)
@@ -140,13 +144,20 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) {
br.EventProcessor.On(event.EventReaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent)
+ br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent)
br.EventProcessor.On(event.StateMember, br.handleRoomEvent)
br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent)
+ br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent)
br.EventProcessor.On(event.StateTopic, br.handleRoomEvent)
+ br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent)
+ br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent)
+ br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent)
+ br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent)
br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent)
br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent)
+ br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent)
br.Bot = br.AS.BotIntent()
br.Crypto = NewCryptoHelper(br)
br.Bridge.Commands.(*commands.Processor).AddHandlers(
@@ -168,6 +179,17 @@ func (br *Connector) Start(ctx context.Context) error {
if err != nil {
return err
}
+ needsStateResync := br.Config.Encryption.Default &&
+ br.Bridge.DB.KV.Get(ctx, database.KeyEncryptionStateResynced) != "true"
+ if needsStateResync {
+ dbExists, err := br.StateStore.TableExists(ctx, "mx_version")
+ if err != nil {
+ return fmt.Errorf("failed to check if mx_version table exists: %w", err)
+ } else if !dbExists {
+ needsStateResync = false
+ br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true")
+ }
+ }
err = br.StateStore.Upgrade(ctx)
if err != nil {
return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err}
@@ -211,17 +233,59 @@ func (br *Connector) Start(ctx context.Context) error {
br.wsStopPinger = make(chan struct{}, 1)
go br.websocketServerPinger()
}
+ if needsStateResync {
+ br.ResyncEncryptionState(ctx)
+ }
return nil
}
+func (br *Connector) ResyncEncryptionState(ctx context.Context) {
+ log := zerolog.Ctx(ctx)
+ roomIDScanner := dbutil.ConvertRowFn[id.RoomID](dbutil.ScanSingleColumn[id.RoomID])
+ rooms, err := roomIDScanner.NewRowIter(br.Bridge.DB.Query(ctx, `
+ SELECT rooms.room_id
+ FROM (SELECT DISTINCT(room_id) FROM mx_user_profile WHERE room_id<>'') rooms
+ LEFT JOIN mx_room_state ON rooms.room_id = mx_room_state.room_id
+ WHERE mx_room_state.encryption IS NULL
+ `)).AsList()
+ if err != nil {
+ log.Err(err).Msg("Failed to get room list to resync state")
+ return
+ }
+ var failedCount, successCount, forbiddenCount int
+ for _, roomID := range rooms {
+ if roomID == "" {
+ continue
+ }
+ var outContent *event.EncryptionEventContent
+ err = br.Bot.Client.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent)
+ if errors.Is(err, mautrix.MForbidden) {
+ // Most likely non-existent room
+ log.Debug().Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room")
+ forbiddenCount++
+ } else if err != nil {
+ log.Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room")
+ failedCount++
+ } else {
+ successCount++
+ }
+ }
+ br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true")
+ log.Info().
+ Int("success_count", successCount).
+ Int("forbidden_count", forbiddenCount).
+ Int("failed_count", failedCount).
+ Msg("Resynced rooms")
+}
+
func (br *Connector) GetPublicAddress() string {
if br.Config.AppService.PublicAddress == "https://bridge.example.com" {
return ""
}
- return br.Config.AppService.PublicAddress
+ return strings.TrimRight(br.Config.AppService.PublicAddress, "/")
}
-func (br *Connector) GetRouter() *mux.Router {
+func (br *Connector) GetRouter() *http.ServeMux {
if br.GetPublicAddress() != "" {
return br.AS.Router
}
@@ -280,16 +344,18 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) {
}
func (br *Connector) ensureConnection(ctx context.Context) {
+ triedToRegister := false
for {
versions, err := br.Bot.Versions(ctx)
if err != nil {
- if errors.Is(err, mautrix.MForbidden) {
+ if errors.Is(err, mautrix.MForbidden) && !triedToRegister {
br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying")
err = br.Bot.EnsureRegistered(ctx)
if err != nil {
br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN")
os.Exit(16)
}
+ triedToRegister = true
} else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) {
br.logInitialRequestError(err, "/versions request failed with auth error")
os.Exit(16)
@@ -302,6 +368,9 @@ func (br *Connector) ensureConnection(ctx context.Context) {
*br.AS.SpecVersions = *versions
br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites)
br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending)
+ br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange)
+ br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) ||
+ (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo)
break
}
}
@@ -346,6 +415,21 @@ func (br *Connector) ensureConnection(ctx context.Context) {
br.Bot.EnsureAppserviceConnection(ctx)
}
+func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities {
+ br.specCapsLock.Lock()
+ defer br.specCapsLock.Unlock()
+ if br.SpecCaps != nil {
+ return br.SpecCaps
+ }
+ caps, err := br.Bot.Capabilities(ctx)
+ if err != nil {
+ br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver")
+ return nil
+ }
+ br.SpecCaps = caps
+ return caps
+}
+
func (br *Connector) fetchMediaConfig(ctx context.Context) {
cfg, err := br.Bot.GetMediaConfig(ctx)
if err != nil {
@@ -411,11 +495,15 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI {
func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error {
if br.Websocket {
br.hasSentAnyStates = true
- return br.AS.SendWebsocket(&appservice.WebsocketRequest{
+ return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
Command: "bridge_status",
Data: state,
})
} else if br.Config.Homeserver.StatusEndpoint != "" {
+ // Connecting states aren't really relevant unless the bridge runs somewhere with an unreliable network
+ if state.StateEvent == status.StateConnecting {
+ return nil
+ }
return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken)
} else {
return nil
@@ -433,7 +521,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
log := zerolog.Ctx(ctx)
if !evt.IsSourceEventDoublePuppeted {
- err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)})
+ err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{ms.ToCheckpoint(evt)})
if err != nil {
log.Err(err).Msg("Failed to send message checkpoint")
}
@@ -450,7 +538,8 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
Msg("Failed to send MSS event")
}
}
- if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
+ if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice &&
+ (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
content := ms.ToNoticeEvent(evt)
if editEvent != "" {
content.SetEdit(editEvent)
@@ -478,11 +567,11 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
return ""
}
-func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error {
+func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*status.MessageCheckpoint) error {
checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints}
if br.Websocket {
- return br.AS.SendWebsocket(&appservice.WebsocketRequest{
+ return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
Command: "message_checkpoint",
Data: checkpointsJSON,
})
@@ -493,7 +582,7 @@ func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpo
return nil
}
- return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken)
+ return checkpointsJSON.SendHTTP(ctx, br.AS.HTTPClient, endpoint, br.AS.Registration.AppToken)
}
func (br *Connector) ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) {
@@ -533,6 +622,31 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve
return br.Bot.PowerLevels(ctx, roomID)
}
+func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) {
+ if stateKey == "" {
+ switch eventType {
+ case event.StateCreate:
+ createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID)
+ if err != nil || createEvt != nil {
+ return createEvt, err
+ }
+ case event.StateJoinRules:
+ joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID)
+ if err != nil {
+ return nil, err
+ } else if joinRulesContent != nil {
+ return &event.Event{
+ Type: event.StateJoinRules,
+ RoomID: roomID,
+ StateKey: ptr.Ptr(""),
+ Content: event.Content{Parsed: joinRulesContent},
+ }, nil
+ }
+ }
+ }
+ return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey)
+}
+
func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID)
if err != nil {
@@ -573,7 +687,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr
if intent != nil {
intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp)
}
- if evt.Type != event.EventEncrypted {
+ if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction {
err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content)
if err != nil {
return nil, err
@@ -605,7 +719,7 @@ func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid.
eventID[1+hashB64Len] = ':'
copy(eventID[1+hashB64Len+1:], br.deterministicEventIDServer)
- return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID)))
+ return id.EventID(exbytes.UnsafeString(eventID))
}
func (br *Connector) GenerateDeterministicRoomID(key networkid.PortalKey) id.RoomID {
diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go
index 47226625..7f18f1f5 100644
--- a/bridgev2/matrix/crypto.go
+++ b/bridgev2/matrix/crypto.go
@@ -24,6 +24,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
+ "maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event"
@@ -37,9 +38,9 @@ func init() {
var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil)
-var NoSessionFound = crypto.NoSessionFound
-var DuplicateMessageIndex = crypto.DuplicateMessageIndex
-var UnknownMessageIndex = olm.UnknownMessageIndex
+var NoSessionFound = crypto.ErrNoSessionFound
+var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex
+var UnknownMessageIndex = olm.ErrUnknownMessageIndex
type CryptoHelper struct {
bridge *Connector
@@ -135,7 +136,19 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
return err
}
if isExistingDevice {
- helper.verifyKeysAreOnServer(ctx)
+ if !helper.verifyKeysAreOnServer(ctx) {
+ return nil
+ }
+ } else {
+ err = helper.ShareKeys(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to share device keys: %w", err)
+ }
+ }
+ if helper.bridge.Config.Encryption.SelfSign {
+ if !helper.doSelfSign(ctx) {
+ os.Exit(34)
+ }
}
go helper.resyncEncryptionInfo(context.TODO())
@@ -143,6 +156,46 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
return nil
}
+func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool {
+ log := zerolog.Ctx(ctx)
+ hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx)
+ if err != nil {
+ log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status")
+ return false
+ }
+ log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status")
+ keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey)
+ if !hasKeys || keyInDB == "overwrite" {
+ if keyInDB != "" && keyInDB != "overwrite" {
+ log.WithLevel(zerolog.FatalLevel).
+ Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.")
+ return false
+ }
+ recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx)
+ if recoveryKey != "" {
+ helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey)
+ }
+ if err != nil {
+ log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign")
+ return false
+ }
+ log.Info().Msg("Generated new recovery key and self-signed bot device")
+ } else if !isVerified {
+ if keyInDB == "" {
+ log.WithLevel(zerolog.FatalLevel).
+ Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.")
+ return false
+ }
+ err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB)
+ if err != nil {
+ log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key")
+ return false
+ }
+ log.Info().Msg("Verified bot device with existing recovery key")
+ }
+ return true
+}
+
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
@@ -157,12 +210,12 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
var evt event.EncryptionEventContent
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
if err != nil {
- log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
+ log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event")
_, err = helper.store.DB.Exec(ctx, `
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
`, roomID)
if err != nil {
- log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync")
+ log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync")
}
} else {
maxAge := evt.RotationPeriodMillis
@@ -185,9 +238,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
`, maxAge, maxMessages, roomID)
if err != nil {
- log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table")
+ log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table")
} else {
- log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table")
+ log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table")
}
}
}
@@ -233,7 +286,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
if err != nil {
return nil, false, fmt.Errorf("failed to find existing device ID: %w", err)
} else if len(deviceID) > 0 {
- helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
+ helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database")
}
// Create a new client instance with the default AS settings (including as_token),
// the Login call will then override the access token in the client.
@@ -274,7 +327,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
return client, deviceID != "", nil
}
-func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) {
+func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool {
helper.log.Debug().Msg("Making sure keys are still on server")
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
@@ -287,10 +340,11 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) {
}
device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID]
if ok && len(device.Keys) > 0 {
- return
+ return true
}
helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto")
helper.Reset(ctx, false)
+ return false
}
func (helper *CryptoHelper) Start() {
@@ -385,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy
var encrypted *event.EncryptedEventContent
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
- if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
+ if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) {
return
}
helper.log.Debug().Err(err).
diff --git a/bridgev2/matrix/cryptostore.go b/bridgev2/matrix/cryptostore.go
index 234797a6..4c3b5d30 100644
--- a/bridgev2/matrix/cryptostore.go
+++ b/bridgev2/matrix/cryptostore.go
@@ -45,7 +45,7 @@ func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context,
WHERE room_id=$1
AND (membership='join' OR membership='invite')
AND user_id<>$2
- AND user_id NOT LIKE $3
+ AND user_id NOT LIKE $3 ESCAPE '\'
`, roomID, store.UserID, store.GhostIDFormat)
if err != nil {
return
diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go
index 71c01078..0667981a 100644
--- a/bridgev2/matrix/directmedia.go
+++ b/bridgev2/matrix/directmedia.go
@@ -39,7 +39,7 @@ func (br *Connector) initDirectMedia() error {
if err != nil {
return fmt.Errorf("failed to initialize media proxy: %w", err)
}
- br.MediaProxy.RegisterRoutes(br.AS.Router)
+ br.MediaProxy.RegisterRoutes(br.AS.Router, br.Log.With().Str("component", "media proxy").Logger())
br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed())
dmn.SetUseDirectMedia()
br.Log.Debug().Str("server_name", br.MediaProxy.GetServerName()).Msg("Enabled direct media access")
diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go
index 2088d5b1..f7254bd4 100644
--- a/bridgev2/matrix/intent.go
+++ b/bridgev2/matrix/intent.go
@@ -9,6 +9,7 @@ package matrix
import (
"bytes"
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -27,6 +28,7 @@ import (
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
"maunium.net/go/mautrix/crypto/attachment"
+ "maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
@@ -43,13 +45,13 @@ type ASIntent struct {
var _ bridgev2.MatrixAPI = (*ASIntent)(nil)
var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil)
+var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil)
func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) {
if extra == nil {
extra = &bridgev2.MatrixSendExtra{}
}
- // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions
- if eventType == event.EventRedaction {
+ if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) {
parsedContent := content.Parsed.(*event.RedactionEventContent)
as.Matrix.AddDoublePuppetValue(content)
return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{
@@ -57,7 +59,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
Extra: content.Raw,
})
}
- if eventType != event.EventReaction && eventType != event.EventRedaction {
+ if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction {
msgContent, ok := content.Parsed.(*event.MessageEventContent)
if ok {
msgContent.AddPerMessageProfileFallback()
@@ -82,16 +84,27 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
eventType = event.EventEncrypted
}
}
- if extra.Timestamp.IsZero() {
- return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content)
- } else {
- return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli())
+ return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()})
+}
+
+func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) {
+ if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
+ return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
}
+ if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil {
+ return nil, fmt.Errorf("failed to check if room is encrypted: %w", err)
+ } else if encrypted && as.Connector.Crypto != nil {
+ if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil {
+ return nil, err
+ }
+ eventType = event.EventEncrypted
+ }
+ return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID})
}
func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) {
- targetContent := content.Parsed.(*event.MemberEventContent)
- if targetContent.Displayname != "" || targetContent.AvatarURL != "" {
+ targetContent, ok := content.Parsed.(*event.MemberEventContent)
+ if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" {
return
}
memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID)
@@ -126,11 +139,7 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e
if eventType == event.StateMember {
as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content)
}
- if ts.IsZero() {
- resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content)
- } else {
- resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli())
- }
+ resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()})
if err != nil && eventType == event.StateMember {
var httpErr mautrix.HTTPError
if errors.As(err, &httpErr) && httpErr.RespError != nil &&
@@ -412,6 +421,7 @@ func (as *ASIntent) UploadMediaStream(
removeAndClose(replFile)
removeAndClose(tempFile)
}
+ req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
startedAsyncUpload = true
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
@@ -444,6 +454,7 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn
as.Connector.uploadSema.Release(int64(len(req.ContentBytes)))
}
}
+ req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
if resp != nil {
@@ -475,11 +486,62 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr
return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL)
}
-func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error {
- if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
- return nil
+func dataToFields(data any) (map[string]json.RawMessage, error) {
+ fields, ok := data.(map[string]json.RawMessage)
+ if ok {
+ return fields, nil
}
- return as.Matrix.BeeperUpdateProfile(ctx, data)
+ d, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+ d = canonicaljson.CanonicalJSONAssumeValid(d)
+ err = json.Unmarshal(d, &fields)
+ return fields, err
+}
+
+func marshalField(val any) json.RawMessage {
+ data, _ := json.Marshal(val)
+ if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
+ return canonicaljson.CanonicalJSONAssumeValid(data)
+ }
+ return data
+}
+
+var nullJSON = json.RawMessage("null")
+
+func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error {
+ if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
+ return as.Matrix.BeeperUpdateProfile(ctx, data)
+ } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo {
+ fields, err := dataToFields(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal fields: %w", err)
+ }
+ currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID)
+ if err != nil {
+ return fmt.Errorf("failed to get current profile: %w", err)
+ }
+ for key, val := range fields {
+ existing, ok := currentProfile.Extra[key]
+ if !ok {
+ if bytes.Equal(val, nullJSON) {
+ continue
+ }
+ err = as.Matrix.SetProfileField(ctx, key, val)
+ } else if !bytes.Equal(marshalField(existing), val) {
+ if bytes.Equal(val, nullJSON) {
+ err = as.Matrix.DeleteProfileField(ctx, key)
+ } else {
+ err = as.Matrix.SetProfileField(ctx, key, val)
+ }
+ }
+ if err != nil {
+ return fmt.Errorf("failed to set profile field %q: %w", key, err)
+ }
+ }
+ }
+ return nil
}
func (as *ASIntent) GetMXID() id.UserID {
@@ -490,8 +552,12 @@ func (as *ASIntent) IsDoublePuppet() bool {
return as.Matrix.IsDoublePuppet()
}
-func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error {
- err := as.Matrix.EnsureJoined(ctx, roomID)
+func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...bridgev2.EnsureJoinedParams) error {
+ var params bridgev2.EnsureJoinedParams
+ if len(extra) > 0 {
+ params = extra[0]
+ }
+ err := as.Matrix.EnsureJoined(ctx, roomID, appservice.EnsureJoinedParams{Via: params.Via})
if err != nil {
return err
}
@@ -517,6 +583,39 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent {
return content
}
+func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) {
+ if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
+ // Hungryserv doesn't override the capabilities endpoint nor do room versions
+ return
+ }
+ caps := as.Connector.fetchCapabilities(ctx)
+ roomVer := req.RoomVersion
+ if roomVer == "" && caps != nil && caps.RoomVersions != nil {
+ roomVer = id.RoomVersion(caps.RoomVersions.Default)
+ }
+ if roomVer != "" && !roomVer.PrivilegedRoomCreators() {
+ return
+ }
+ creators, _ := req.CreationContent["additional_creators"].([]id.UserID)
+ creators = append(slices.Clone(creators), as.GetMXID())
+ if req.PowerLevelOverride != nil {
+ for _, creator := range creators {
+ delete(req.PowerLevelOverride.Users, creator)
+ }
+ }
+ for _, evt := range req.InitialState {
+ if evt.Type != event.StatePowerLevels {
+ continue
+ }
+ content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent)
+ if ok {
+ for _, creator := range creators {
+ delete(content.Users, creator)
+ }
+ }
+ }
+}
+
func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) {
if as.Connector.Config.Encryption.Default {
req.InitialState = append(req.InitialState, &event.Event{
@@ -532,6 +631,7 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom)
}
req.CreationContent["m.federate"] = false
}
+ as.filterCreateRequestForV12(ctx, req)
resp, err := as.Matrix.CreateRoom(ctx, req)
if err != nil {
return "", err
@@ -573,8 +673,19 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id.
}
func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error {
+ if roomID == "" {
+ return nil
+ }
if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) {
- return as.Matrix.BeeperDeleteRoom(ctx, roomID)
+ err := as.Matrix.BeeperDeleteRoom(ctx, roomID)
+ if err != nil {
+ return err
+ }
+ err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal")
+ }
+ return nil
}
members, err := as.Matrix.JoinedMembers(ctx, roomID)
if err != nil {
@@ -662,3 +773,23 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T
})
}
}
+
+func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) {
+ evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID)
+ if err != nil {
+ return nil, err
+ }
+ err = evt.Content.ParseRaw(evt.Type)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content")
+ }
+
+ if evt.Type == event.EventEncrypted {
+ if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt {
+ return nil, errors.New("can't decrypt the event")
+ }
+ return as.Connector.Crypto.Decrypt(ctx, evt)
+ }
+
+ return evt, nil
+}
diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go
index 84e85d24..954d0ad9 100644
--- a/bridgev2/matrix/matrix.go
+++ b/bridgev2/matrix/matrix.go
@@ -27,6 +27,11 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) {
if br.shouldIgnoreEvent(evt) {
return
}
+ if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember {
+ zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events")
+ br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
+ return
+ }
if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require {
zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required")
br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true)
@@ -63,6 +68,10 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event)
case event.EphemeralEventTyping:
typingContent := evt.Content.AsTyping()
typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser)
+ case event.BeeperEphemeralEventAIStream:
+ if br.shouldIgnoreEvent(evt) {
+ return
+ }
}
br.Bridge.QueueMatrixEvent(ctx, evt)
}
@@ -76,6 +85,11 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
+ if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents {
+ log.Debug().Msg("Dropping event from user with no permission to send events")
+ br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
+ return
+ }
ctx = log.WithContext(ctx)
if br.Crypto == nil {
br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true)
@@ -87,17 +101,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
decryptionStart := time.Now()
decrypted, err := br.Crypto.Decrypt(ctx, evt)
decryptionRetryCount := 0
+ var errorEventID id.EventID
if errors.Is(err, NoSessionFound) {
decryptionRetryCount = 1
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
- go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false)
+ go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false)
if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = br.Crypto.Decrypt(ctx, evt)
} else {
- go br.waitLongerForSession(ctx, evt, decryptionStart)
+ go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID)
return
}
}
@@ -106,18 +121,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true)
return
}
- br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart))
+ br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart))
}
-func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) {
+func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) {
log := zerolog.Ctx(ctx)
content := evt.Content.AsEncrypted()
log.Debug().
Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, requesting keys and waiting longer...")
+ //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
- var errorEventID *id.EventID
go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false)
if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@@ -142,7 +157,7 @@ type CommandProcessor interface {
}
func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event, step status.MessageCheckpointStep, retryNum int) {
- err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{{
+ err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{{
RoomID: evt.RoomID,
EventID: evt.ID,
EventType: evt.Type,
@@ -169,7 +184,7 @@ func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool {
}
func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool {
- if br.shouldIgnoreEventFromUser(evt.Sender) {
+ if br.shouldIgnoreEventFromUser(evt.Sender) && evt.Type != event.StateTombstone {
return true
}
dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey]
@@ -220,7 +235,6 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event
go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount)
decrypted.Mautrix.CheckpointSent = true
decrypted.Mautrix.DecryptionDuration = duration
- decrypted.Mautrix.EventSource |= event.SourceDecrypted
br.EventProcessor.Dispatch(ctx, decrypted)
if errorEventID != nil && *errorEventID != "" {
_, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID)
diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go
index 0f6aa68c..f5e438de 100644
--- a/bridgev2/matrix/mxmain/dberror.go
+++ b/bridgev2/matrix/mxmain/dberror.go
@@ -66,7 +66,12 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s
} else if errors.Is(err, dbutil.ErrForeignTables) {
br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info")
} else if errors.Is(err, dbutil.ErrNotOwned) {
- br.Log.Info().Msg("Sharing the same database with different programs is not supported")
+ var noe dbutil.NotOwnedError
+ if errors.As(err, &noe) && noe.Owner == br.Name {
+ br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?")
+ } else {
+ br.Log.Info().Msg("Sharing the same database with different programs is not supported")
+ }
} else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) {
br.Log.Info().Msg("Downgrading the bridge is not supported")
}
diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go
new file mode 100644
index 00000000..1b4f1467
--- /dev/null
+++ b/bridgev2/matrix/mxmain/envconfig.go
@@ -0,0 +1,161 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mxmain
+
+import (
+ "fmt"
+ "iter"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+
+ "go.mau.fi/util/random"
+)
+
+var randomParseFilePrefix = random.String(16) + "READFILE:"
+
+func parseEnv(prefix string) iter.Seq2[[]string, string] {
+ return func(yield func([]string, string) bool) {
+ for _, s := range os.Environ() {
+ if !strings.HasPrefix(s, prefix) {
+ continue
+ }
+ kv := strings.SplitN(s, "=", 2)
+ key := strings.TrimPrefix(kv[0], prefix)
+ value := kv[1]
+ if strings.HasSuffix(key, "_FILE") {
+ key = strings.TrimSuffix(key, "_FILE")
+ value = randomParseFilePrefix + value
+ }
+ key = strings.ToLower(key)
+ if !strings.ContainsRune(key, '.') {
+ key = strings.ReplaceAll(key, "__", ".")
+ }
+ if !yield(strings.Split(key, "."), value) {
+ return
+ }
+ }
+ }
+}
+
+func reflectYAMLFieldName(f *reflect.StructField) string {
+ parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2)
+ fieldName := parts[0]
+ if fieldName == "-" && len(parts) == 1 {
+ return ""
+ }
+ if fieldName == "" {
+ return strings.ToLower(f.Name)
+ }
+ return fieldName
+}
+
+type reflectGetResult struct {
+ val reflect.Value
+ valKind reflect.Kind
+ remainingPath []string
+}
+
+func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) {
+ if len(path) == 0 {
+ return &reflectGetResult{val: rv, valKind: rv.Kind()}, true
+ }
+ if rv.Kind() == reflect.Ptr {
+ rv = rv.Elem()
+ }
+ switch rv.Kind() {
+ case reflect.Map:
+ return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true
+ case reflect.Struct:
+ fields := reflect.VisibleFields(rv.Type())
+ for _, field := range fields {
+ fieldName := reflectYAMLFieldName(&field)
+ if fieldName != "" && fieldName == path[0] {
+ return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:])
+ }
+ }
+ default:
+ }
+ return nil, false
+}
+
+func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) {
+ if len(path) > 0 && path[0] == "network" {
+ return reflectGetYAML(network, path[1:])
+ }
+ return reflectGetYAML(main, path)
+}
+
+func formatKeyString(key []string) string {
+ return strings.Join(key, "->")
+}
+
+func UpdateConfigFromEnv(cfg, networkData any, prefix string) error {
+ cfgVal := reflect.ValueOf(cfg)
+ networkVal := reflect.ValueOf(networkData)
+ for key, value := range parseEnv(prefix) {
+ field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key)
+ if !ok {
+ return fmt.Errorf("%s not found", formatKeyString(key))
+ }
+ if strings.HasPrefix(value, randomParseFilePrefix) {
+ filepath := strings.TrimPrefix(value, randomParseFilePrefix)
+ fileData, err := os.ReadFile(filepath)
+ if err != nil {
+ return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err)
+ }
+ value = strings.TrimSpace(string(fileData))
+ }
+ var parsedVal any
+ var err error
+ switch field.valKind {
+ case reflect.String:
+ parsedVal = value
+ case reflect.Bool:
+ parsedVal, err = strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ parsedVal, err = strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ parsedVal, err = strconv.ParseUint(value, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ case reflect.Float32, reflect.Float64:
+ parsedVal, err = strconv.ParseFloat(value, 64)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ default:
+ return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key))
+ }
+ if field.val.Kind() == reflect.Ptr {
+ if field.val.IsNil() {
+ field.val.Set(reflect.New(field.val.Type().Elem()))
+ }
+ field.val = field.val.Elem()
+ }
+ if field.val.Kind() == reflect.Map {
+ key = key[:len(key)-len(field.remainingPath)]
+ mapKeyStr := strings.Join(field.remainingPath, ".")
+ key = append(key, mapKeyStr)
+ if field.val.Type().Key().Kind() != reflect.String {
+ return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key))
+ }
+ field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal))
+ } else {
+ field.val.Set(reflect.ValueOf(parsedVal))
+ }
+ }
+ return nil
+}
diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml
index 48e0d528..ccc81c4b 100644
--- a/bridgev2/matrix/mxmain/example-config.yaml
+++ b/bridgev2/matrix/mxmain/example-config.yaml
@@ -15,6 +15,7 @@ bridge:
# By default, users who are in the same group on the remote network will be
# in the same Matrix room bridged to that group. If this is set to true,
# every user will get their own Matrix room instead.
+ # SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST.
split_portals: false
# Should the bridge resend `m.bridge` events to all portals on startup?
resend_bridge_info: false
@@ -28,6 +29,9 @@ bridge:
# How long after an unknown error should the bridge attempt a full reconnect?
# Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value.
unknown_error_auto_reconnect: null
+ # Maximum number of times to do the auto-reconnect above.
+ # The counter is per login, but is never reset except on logout and restart.
+ unknown_error_max_auto_reconnects: 10
# Should leaving Matrix rooms be bridged as leaving groups on the remote network?
bridge_matrix_leave: false
@@ -46,6 +50,11 @@ bridge:
# Should cross-room reply metadata be bridged?
# Most Matrix clients don't support this and servers may reject such messages too.
cross_room_replies: false
+ # If a state event fails to bridge, should the bridge revert any state changes made by that event?
+ revert_failed_state_changes: false
+ # In portals with no relay set, should Matrix users be kicked if they're
+ # not logged into an account that's in the remote chat?
+ kick_matrix_users: true
# What should be done to portal rooms when a user logs out or is logged out?
# Permitted values:
@@ -235,6 +244,9 @@ matrix:
# The threshold as bytes after which the bridge should roundtrip uploads via the disk
# rather than keeping the whole file in memory.
upload_file_threshold: 5242880
+ # Should the bridge set additional custom profile info for ghosts?
+ # This can make a lot of requests, as there's no batch profile update endpoint.
+ ghost_extra_profile_info: false
# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors.
analytics:
@@ -247,10 +259,8 @@ analytics:
# Settings for provisioning API
provisioning:
- # Prefix for the provisioning API paths.
- prefix: /_matrix/provision
# Shared secret for authentication. If set to "generate" or null, a random secret will be generated,
- # or if set to "disable", the provisioning API will be disabled.
+ # or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters.
shared_secret: generate
# Whether to allow provisioning API requests to be authed using Matrix access tokens.
# This follows the same rules as double puppeting to determine which server to contact to check the token,
@@ -276,6 +286,14 @@ public_media:
expiry: 0
# Length of hash to use for public media URLs. Must be between 0 and 32.
hash_length: 32
+ # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served.
+ # If you change this, you must configure your reverse proxy to rewrite the path accordingly.
+ path_prefix: /_mautrix/publicmedia
+ # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs?
+ # If false, the generated URLs will just have the MXC URI and a HMAC signature.
+ # The hash_length field will be used to decide the length of the generated URL.
+ # This also allows invalidating URLs by deleting the database entry.
+ use_database: false
# Settings for converting remote media to custom mxc:// URIs instead of reuploading.
# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html
@@ -366,6 +384,12 @@ encryption:
# Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861).
# Changing this option requires updating the appservice registration file.
msc4190: false
+ # Whether to encrypt reactions and reply metadata as per MSC4392.
+ msc4392: false
+ # Should the bridge bot generate a recovery key and cross-signing keys and verify itself?
+ # Note that without the latest version of MSC4190, this will fail if you reset the bridge database.
+ # The generated recovery key will be saved in the kv_store table under `recovery_key`.
+ self_sign: false
# Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled.
# You must use a client that supports requesting keys from other users to use this feature.
allow_key_sharing: true
@@ -428,6 +452,16 @@ encryption:
# You should not enable this option unless you understand all the implications.
disable_device_change_key_rotation: false
+# Prefix for environment variables. All variables with this prefix must map to valid config fields.
+# Nesting in variable names is represented with a dot (.).
+# If there are no dots in the name, two underscores (__) are replaced with a dot.
+#
+# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token.
+# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily.
+#
+# If this is null, reading config fields from environment will be disabled.
+env_config_prefix: null
+
# Logging config. See https://github.com/tulir/zeroconfig for details.
logging:
min_level: debug
diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go
index c8eb820b..97cdeddf 100644
--- a/bridgev2/matrix/mxmain/legacymigrate.go
+++ b/bridgev2/matrix/mxmain/legacymigrate.go
@@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB(
}
var dbVersion int
err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion)
- if dbVersion < expectedVersion {
+ if err != nil {
+ log.Fatal().Err(err).Msg("Failed to get database version")
+ return
+ } else if dbVersion < expectedVersion {
log.Fatal().
Int("expected_version", expectedVersion).
Int("version", dbVersion).
diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go
index e6219c50..1e8b51d1 100644
--- a/bridgev2/matrix/mxmain/main.go
+++ b/bridgev2/matrix/mxmain/main.go
@@ -26,6 +26,7 @@ import (
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exzerolog"
+ "go.mau.fi/util/progver"
"gopkg.in/yaml.v3"
flag "maunium.net/go/mauflag"
@@ -62,6 +63,9 @@ type BridgeMain struct {
// git tag to see if the built version is the release or a dev build.
// You can either bump this right after a release or right before, as long as it matches on the release commit.
Version string
+ // SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning,
+ // such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch.
+ SemCalVer bool
// PostInit is a function that will be called after the bridge has been initialized but before it is started.
PostInit func()
@@ -86,11 +90,7 @@ type BridgeMain struct {
RegistrationPath string
SaveConfig bool
- baseVersion string
- commit string
- LinkifiedVersion string
- VersionDesc string
- BuildTime time.Time
+ ver progver.ProgramVersion
AdditionalShortFlags string
AdditionalLongFlags string
@@ -99,14 +99,7 @@ type BridgeMain struct {
}
type VersionJSONOutput struct {
- Name string
- URL string
-
- Version string
- IsRelease bool
- Commit string
- FormattedVersion string
- BuildTime time.Time
+ progver.ProgramVersion
OS string
Arch string
@@ -147,18 +140,11 @@ func (br *BridgeMain) PreInit() {
flag.PrintHelp()
os.Exit(0)
} else if *version {
- fmt.Println(br.VersionDesc)
+ fmt.Println(br.ver.VersionDescription)
os.Exit(0)
} else if *versionJSON {
output := VersionJSONOutput{
- URL: br.URL,
- Name: br.Name,
-
- Version: br.baseVersion,
- IsRelease: br.Version == br.baseVersion,
- Commit: br.commit,
- FormattedVersion: br.Version,
- BuildTime: br.BuildTime,
+ ProgramVersion: br.ver,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
@@ -240,8 +226,8 @@ func (br *BridgeMain) Init() {
br.Log.Info().
Str("name", br.Name).
- Str("version", br.Version).
- Time("built_at", br.BuildTime).
+ Str("version", br.ver.FormattedVersion).
+ Time("built_at", br.ver.BuildTime).
Str("go_version", runtime.Version()).
Msg("Initializing bridge")
@@ -255,7 +241,7 @@ func (br *BridgeMain) Init() {
br.Matrix.AS.DoublePuppetValue = br.Name
br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{
Func: func(ce *commands.Event) {
- ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123))
+ ce.Reply(br.ver.MarkdownDescription())
},
Name: "version",
Help: commands.HelpMeta{
@@ -368,6 +354,13 @@ func (br *BridgeMain) LoadConfig() {
}
}
cfg.Bridge.Backfill = cfg.Backfill
+ if cfg.EnvConfigPrefix != "" {
+ err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix)
+ if err != nil {
+ _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err)
+ os.Exit(10)
+ }
+ }
br.Config = &cfg
}
@@ -446,42 +439,12 @@ func (br *BridgeMain) Stop() {
//
// (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`)
func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) {
- br.baseVersion = br.Version
- if len(tag) > 0 && tag[0] == 'v' {
- tag = tag[1:]
- }
- if tag != br.Version {
- suffix := ""
- if !strings.HasSuffix(br.Version, "+dev") {
- suffix = "+dev"
- }
- if len(commit) > 8 {
- br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8])
- } else {
- br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix)
- }
- }
-
- br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version)
- if tag == br.Version {
- br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag)
- } else if len(commit) > 8 {
- br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1)
- }
- var buildTime time.Time
- if rawBuildTime != "unknown" {
- buildTime, _ = time.Parse(time.RFC3339, rawBuildTime)
- }
- var builtWith string
- if buildTime.IsZero() {
- rawBuildTime = "unknown"
- builtWith = runtime.Version()
- } else {
- rawBuildTime = buildTime.Format(time.RFC1123)
- builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version())
- }
- mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent)
- br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith)
- br.commit = commit
- br.BuildTime = buildTime
+ br.ver = progver.ProgramVersion{
+ Name: br.Name,
+ URL: br.URL,
+ BaseVersion: br.Version,
+ SemCalVer: br.SemCalVer,
+ }.Init(tag, commit, rawBuildTime)
+ mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent)
+ br.Version = br.ver.FormattedVersion
}
diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go
index f865a19e..243b91da 100644
--- a/bridgev2/matrix/provisioning.go
+++ b/bridgev2/matrix/provisioning.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -17,18 +17,20 @@ import (
"sync"
"time"
- "github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/exstrings"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
"go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/bridgev2/provisionutil"
"maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/federation"
"maunium.net/go/mautrix/id"
@@ -40,7 +42,7 @@ type matrixAuthCacheEntry struct {
}
type ProvisioningAPI struct {
- Router *mux.Router
+ Router *http.ServeMux
br *Connector
log zerolog.Logger
@@ -83,24 +85,18 @@ const (
provisioningUserKey provisioningContextKey = iota
provisioningUserLoginKey
provisioningLoginProcessKey
+ ProvisioningKeyRequest
)
-const ProvisioningKeyRequest = "fi.mau.provision.request"
-
func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}
-func (prov *ProvisioningAPI) GetRouter() *mux.Router {
+func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
return prov.Router
}
-type IProvisioningAPI interface {
- GetRouter() *mux.Router
- GetUser(r *http.Request) *bridgev2.User
-}
-
-func (br *Connector) GetProvisioning() IProvisioningAPI {
+func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI {
return br.Provisioning
}
@@ -116,41 +112,57 @@ func (prov *ProvisioningAPI) Init() {
tp.Dialer.Timeout = 10 * time.Second
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
- prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
- prov.Router.Use(hlog.NewHandler(prov.log))
- prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id"))
- prov.Router.Use(exhttp.CORSMiddleware)
- prov.Router.Use(requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}))
- prov.Router.Use(prov.AuthMiddleware)
- prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
- prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
- prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
- prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
- prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
- prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
- prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
- prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
- prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers)
- prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
- prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
- prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)
+ prov.Router = http.NewServeMux()
+ prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami)
+ prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities)
+ prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows)
+ prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
+ prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep)
+ prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
+ prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins)
+ prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList)
+ prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
+ prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
+ prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
+ prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup)
if prov.br.Config.Provisioning.EnableSessionTransfers {
prov.log.Debug().Msg("Enabling session transfer API")
- prov.Router.Path("/v3/session_transfer/init").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostInitSessionTransfer)
- prov.Router.Path("/v3/session_transfer/finish").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostFinishSessionTransfer)
+ prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer)
+ prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer)
}
if prov.br.Config.Provisioning.DebugEndpoints {
prov.log.Debug().Msg("Enabling debug API at /debug")
- r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
- r.Use(prov.DebugAuthMiddleware)
- r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet)
- r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet)
- r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet)
- r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet)
- r.PathPrefix("/pprof/").HandlerFunc(pprof.Index)
+ debugRouter := http.NewServeMux()
+ debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline)
+ debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile)
+ debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol)
+ debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace)
+ debugRouter.HandleFunc("/pprof/", pprof.Index)
+ prov.br.AS.Router.Handle("/debug/", exhttp.ApplyMiddleware(
+ debugRouter,
+ exhttp.StripPrefix("/debug"),
+ hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ prov.DebugAuthMiddleware,
+ ))
}
+
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+ }
+ prov.br.AS.Router.Handle("/_matrix/provision/", exhttp.ApplyMiddleware(
+ prov.Router,
+ exhttp.StripPrefix("/_matrix/provision"),
+ hlog.NewHandler(prov.log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ exhttp.CORSMiddleware,
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ prov.AuthMiddleware,
+ ))
}
func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error {
@@ -194,12 +206,20 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI
}
}
+func disabledAuth(w http.ResponseWriter, r *http.Request) {
+ mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w)
+}
+
func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
+ secret := prov.br.Config.Provisioning.SharedSecret
+ if len(secret) < 16 {
+ return http.HandlerFunc(disabledAuth)
+ }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth == "" {
mautrix.MMissingToken.WithMessage("Missing auth token").Write(w)
- } else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
+ } else if !exstrings.ConstantTimeEqual(auth, secret) {
mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w)
} else {
h.ServeHTTP(w, r)
@@ -208,6 +228,10 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
}
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
+ secret := prov.br.Config.Provisioning.SharedSecret
+ if len(secret) < 16 {
+ return http.HandlerFunc(disabledAuth)
+ }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth == "" && prov.GetAuthFromRequest != nil {
@@ -221,7 +245,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
if userID == "" && prov.GetUserIDFromRequest != nil {
userID = prov.GetUserIDFromRequest(r)
}
- if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
+ if !exstrings.ConstantTimeEqual(auth, secret) {
var err error
if strings.HasPrefix(auth, "openid:") {
err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:"))
@@ -250,38 +274,6 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r)
ctx = context.WithValue(ctx, provisioningUserKey, user)
- if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
- prov.loginsLock.RLock()
- login, ok := prov.logins[loginID]
- prov.loginsLock.RUnlock()
- if !ok {
- zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found")
- mautrix.MNotFound.WithMessage("Login not found").Write(w)
- return
- }
- login.Lock.Lock()
- // This will only unlock after the handler runs
- defer login.Lock.Unlock()
- stepID := mux.Vars(r)["stepID"]
- if login.NextStep.StepID != stepID {
- zerolog.Ctx(r.Context()).Warn().
- Str("request_step_id", stepID).
- Str("expected_step_id", login.NextStep.StepID).
- Msg("Step ID does not match")
- mautrix.MBadState.WithMessage("Step ID does not match").Write(w)
- return
- }
- stepType := mux.Vars(r)["stepType"]
- if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
- zerolog.Ctx(r.Context()).Warn().
- Str("request_step_type", stepType).
- Str("expected_step_type", string(login.NextStep.Type)).
- Msg("Step type does not match")
- mautrix.MBadState.WithMessage("Step type does not match").Write(w)
- return
- }
- ctx = context.WithValue(ctx, provisioningLoginProcessKey, login)
- }
h.ServeHTTP(w, r.WithContext(ctx))
})
}
@@ -332,7 +324,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) {
prevState.UserID = ""
prevState.RemoteID = ""
prevState.RemoteName = ""
- prevState.RemoteProfile = nil
+ prevState.RemoteProfile = status.RemoteProfile{}
resp.Logins[i] = RespWhoamiLogin{
StateEvent: prevState.StateEvent,
StateTS: prevState.Timestamp,
@@ -364,18 +356,24 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques
})
}
+func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) {
+ exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning)
+}
+
var ErrNilStep = errors.New("bridge returned nil step with no error")
+var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"}
func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) {
overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r)
if failed {
return
}
- login, err := prov.net.CreateLogin(
- r.Context(),
- prov.GetUser(r),
- mux.Vars(r)["flowID"],
- )
+ user := prov.GetUser(r)
+ if overrideLogin == nil && user.HasTooManyLogins() {
+ ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w)
+ return
+ }
+ login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID"))
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
RespondWithError(w, err, "Internal error creating login process")
@@ -405,10 +403,18 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
Override: overrideLogin,
}
prov.loginsLock.Unlock()
+ zerolog.Ctx(r.Context()).Info().
+ Any("first_step", firstStep).
+ Msg("Created login process")
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep})
}
func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) {
+ zerolog.Ctx(ctx).Info().
+ Str("step_id", step.StepID).
+ Str("user_login_id", string(step.CompleteParams.UserLoginID)).
+ Msg("Login completed successfully")
+ prov.deleteLogin(login, false)
if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID {
return
}
@@ -422,6 +428,61 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
}, bridgev2.DeleteOpts{LogoutRemote: true})
}
+func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) {
+ if cancel {
+ login.Process.Cancel()
+ }
+ prov.loginsLock.Lock()
+ delete(prov.logins, login.ID)
+ prov.loginsLock.Unlock()
+}
+
+func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) {
+ loginID := r.PathValue("loginProcessID")
+ prov.loginsLock.RLock()
+ login, ok := prov.logins[loginID]
+ prov.loginsLock.RUnlock()
+ if !ok {
+ zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found")
+ mautrix.MNotFound.WithMessage("Login not found").Write(w)
+ return
+ }
+ login.Lock.Lock()
+ // This will only unlock after the handler runs
+ defer login.Lock.Unlock()
+ stepID := r.PathValue("stepID")
+ if login.NextStep.StepID != stepID {
+ zerolog.Ctx(r.Context()).Warn().
+ Str("request_step_id", stepID).
+ Str("expected_step_id", login.NextStep.StepID).
+ Msg("Step ID does not match")
+ mautrix.MBadState.WithMessage("Step ID does not match").Write(w)
+ return
+ }
+ stepType := r.PathValue("stepType")
+ if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
+ zerolog.Ctx(r.Context()).Warn().
+ Str("request_step_type", stepType).
+ Str("expected_step_type", string(login.NextStep.Type)).
+ Msg("Step type does not match")
+ mautrix.MBadState.WithMessage("Step type does not match").Write(w)
+ return
+ }
+ ctx := context.WithValue(r.Context(), provisioningLoginProcessKey, login)
+ r = r.WithContext(ctx)
+ switch bridgev2.LoginStepType(r.PathValue("stepType")) {
+ case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies:
+ prov.PostLoginSubmitInput(w, r)
+ case bridgev2.LoginStepTypeDisplayAndWait:
+ prov.PostLoginWait(w, r)
+ case bridgev2.LoginStepTypeComplete:
+ fallthrough
+ default:
+ // This is probably impossible because of the above check that the next step type matches the request.
+ mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w)
+ }
+}
+
func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) {
var params map[string]string
err := json.NewDecoder(r.Body).Decode(¶ms)
@@ -446,11 +507,14 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input")
RespondWithError(w, err, "Internal error submitting input")
+ prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
+ } else {
+ zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
@@ -464,18 +528,21 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait")
RespondWithError(w, err, "Internal error waiting for login")
+ prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
+ } else {
+ zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
user := prov.GetUser(r)
- userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
+ userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
if userLoginID == "all" {
for {
login := user.GetDefaultLogin()
@@ -552,115 +619,23 @@ func RespondWithError(w http.ResponseWriter, err error, message string) {
}
}
-type RespResolveIdentifier struct {
- ID networkid.UserID `json:"id"`
- Name string `json:"name,omitempty"`
- AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
- Identifiers []string `json:"identifiers,omitempty"`
- MXID id.UserID `json:"mxid,omitempty"`
- DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
-}
-
func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) {
login := prov.GetLoginForRequest(w, r)
if login == nil {
return
}
- api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
- if !ok {
- mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w)
- return
- }
- resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
+ resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat)
if err != nil {
- zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
RespondWithError(w, err, "Internal error resolving identifier")
- return
} else if resp == nil {
mautrix.MNotFound.WithMessage("Identifier not found").Write(w)
- return
- }
- apiResp := &RespResolveIdentifier{
- ID: resp.UserID,
- }
- status := http.StatusOK
- if resp.Ghost != nil {
- if resp.UserInfo != nil {
- resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo)
- }
- apiResp.Name = resp.Ghost.Name
- apiResp.AvatarURL = resp.Ghost.AvatarMXC
- apiResp.Identifiers = resp.Ghost.Identifiers
- apiResp.MXID = resp.Ghost.Intent.GetMXID()
- } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
- apiResp.Name = *resp.UserInfo.Name
- }
- if resp.Chat != nil {
- if resp.Chat.Portal == nil {
- resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey)
- if err != nil {
- zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal")
- mautrix.MUnknown.WithMessage("Failed to get portal").Write(w)
- return
- }
- }
- if createChat && resp.Chat.Portal.MXID == "" {
+ } else {
+ status := http.StatusOK
+ if resp.JustCreated {
status = http.StatusCreated
- err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo)
- if err != nil {
- zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room")
- mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w)
- return
- }
}
- apiResp.DMRoomID = resp.Chat.Portal.MXID
+ exhttp.WriteJSONResponse(w, status, resp)
}
- exhttp.WriteJSONResponse(w, status, apiResp)
-}
-
-type RespGetContactList struct {
- Contacts []*RespResolveIdentifier `json:"contacts"`
-}
-
-func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) {
- apiResp = make([]*RespResolveIdentifier, len(resp))
- for i, contact := range resp {
- apiContact := &RespResolveIdentifier{
- ID: contact.UserID,
- }
- apiResp[i] = apiContact
- if contact.UserInfo != nil {
- if contact.UserInfo.Name != nil {
- apiContact.Name = *contact.UserInfo.Name
- }
- if contact.UserInfo.Identifiers != nil {
- apiContact.Identifiers = contact.UserInfo.Identifiers
- }
- }
- if contact.Ghost != nil {
- if contact.Ghost.Name != "" {
- apiContact.Name = contact.Ghost.Name
- }
- if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
- apiContact.Identifiers = contact.Ghost.Identifiers
- }
- apiContact.AvatarURL = contact.Ghost.AvatarMXC
- apiContact.MXID = contact.Ghost.Intent.GetMXID()
- }
- if contact.Chat != nil {
- if contact.Chat.Portal == nil {
- var err error
- contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
- }
- }
- if contact.Chat.Portal != nil {
- apiContact.DMRoomID = contact.Chat.Portal.MXID
- }
- }
- }
- return
}
func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) {
@@ -668,30 +643,18 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque
if login == nil {
return
}
- api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
- if !ok {
- mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w)
- return
- }
- resp, err := api.GetContactList(r.Context())
+ resp, err := provisionutil.GetContactList(r.Context(), login)
if err != nil {
- zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
- RespondWithError(w, err, "Internal error fetching contact list")
+ RespondWithError(w, err, "Internal error getting contact list")
return
}
- exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{
- Contacts: prov.processResolveIdentifiers(r.Context(), resp),
- })
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
type ReqSearchUsers struct {
Query string `json:"query"`
}
-type RespSearchUsers struct {
- Results []*RespResolveIdentifier `json:"results"`
-}
-
func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) {
var req ReqSearchUsers
err := json.NewDecoder(r.Body).Decode(&req)
@@ -704,20 +667,12 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ
if login == nil {
return
}
- api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
- if !ok {
- mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w)
- return
- }
- resp, err := api.SearchUsers(r.Context(), req.Query)
+ resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query)
if err != nil {
- zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
- RespondWithError(w, err, "Internal error fetching contact list")
+ RespondWithError(w, err, "Internal error searching users")
return
}
- exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{
- Results: prov.processResolveIdentifiers(r.Context(), resp),
- })
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) {
@@ -729,11 +684,24 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request
}
func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) {
+ var req bridgev2.GroupCreateParams
+ err := json.NewDecoder(r.Body).Decode(&req)
+ if err != nil {
+ zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
+ mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
+ return
+ }
+ req.Type = r.PathValue("type")
login := prov.GetLoginForRequest(w, r)
if login == nil {
return
}
- mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w)
+ resp, err := provisionutil.CreateGroup(r.Context(), login, &req)
+ if err != nil {
+ RespondWithError(w, err, "Internal error creating group")
+ return
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, resp)
}
type ReqExportCredentials struct {
diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml
index b9879ea5..26068db4 100644
--- a/bridgev2/matrix/provisioning.yaml
+++ b/bridgev2/matrix/provisioning.yaml
@@ -361,14 +361,25 @@ paths:
$ref: '#/components/responses/InternalError'
501:
$ref: '#/components/responses/NotSupported'
- /v3/create_group:
+ /v3/create_group/{type}:
post:
tags: [ snc ]
summary: Create a group chat on the remote network.
operationId: createGroup
parameters:
- $ref: "#/components/parameters/loginID"
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/GroupCreateParams'
responses:
+ 200:
+ description: Identifier resolved successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/CreatedGroup'
401:
$ref: '#/components/responses/Unauthorized'
404:
@@ -389,7 +400,7 @@ components:
- username
- meow@example.com
loginID:
- name: loginID
+ name: login_id
in: query
description: An optional explicit login ID to do the action through.
required: false
@@ -572,6 +583,74 @@ components:
description: The Matrix room ID of the direct chat with the user.
examples:
- '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
+ GroupCreateParams:
+ type: object
+ description: |
+ Parameters for creating a group chat.
+ The /capabilities endpoint response must be checked to see which fields are actually allowed.
+ properties:
+ type:
+ type: string
+ description: The type of group to create.
+ examples:
+ - channel
+ username:
+ type: string
+ description: The public username for the created group.
+ participants:
+ type: array
+ description: The users to add to the group initially.
+ items:
+ type: string
+ parent:
+ type: object
+ name:
+ type: object
+ description: The `m.room.name` event content for the room.
+ properties:
+ name:
+ type: string
+ avatar:
+ type: object
+ description: The `m.room.avatar` event content for the room.
+ properties:
+ url:
+ type: string
+ format: mxc
+ topic:
+ type: object
+ description: The `m.room.topic` event content for the room.
+ properties:
+ topic:
+ type: string
+ disappear:
+ type: object
+ description: The `com.beeper.disappearing_timer` event content for the room.
+ properties:
+ type:
+ type: string
+ timer:
+ type: number
+ room_id:
+ type: string
+ format: matrix_room_id
+ description: |
+ An existing Matrix room ID to bridge to.
+ The other parameters must be already in sync with the room state when using this parameter.
+ CreatedGroup:
+ type: object
+ description: A successfully created group chat.
+ required: [id, mxid]
+ properties:
+ id:
+ type: string
+ description: The internal chat ID of the created group.
+ mxid:
+ type: string
+ format: matrix_room_id
+ description: The Matrix room ID of the portal.
+ examples:
+ - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
LoginStep:
type: object
description: A step in a login process.
@@ -635,7 +714,7 @@ components:
type:
type: string
description: The type of field.
- enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ]
+ enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ]
id:
type: string
description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge.
@@ -649,10 +728,53 @@ components:
description: A more detailed description of the field shown to the user.
examples:
- Include the country code with a +
+ default_value:
+ type: string
+ description: A default value that the client can pre-fill the field with.
pattern:
type: string
format: regex
description: A regular expression that the field value must match.
+ options:
+ type: array
+ description: For fields of type select, the valid options.
+ items:
+ type: string
+ attachments:
+ type: array
+ description: A list of media attachments to show the user alongside the form fields.
+ items:
+ type: object
+ description: A media attachment to show the user.
+ required: [ type, filename, content ]
+ properties:
+ type:
+ type: string
+ description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported.
+ enum: [ m.image, m.audio ]
+ filename:
+ type: string
+ description: The filename for the media attachment.
+ content:
+ type: string
+ description: The raw file content for the attachment encoded in base64.
+ info:
+ type: object
+ description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted.
+ properties:
+ mimetype:
+ type: string
+ description: The MIME type for the media content.
+ examples: [ image/png, audio/mpeg ]
+ w:
+ type: number
+ description: The width of the media in pixels. Only applicable for images and videos.
+ h:
+ type: number
+ description: The height of the media in pixels. Only applicable for images and videos.
+ size:
+ type: number
+ description: The size of the media content in number of bytes. Strongly recommended to include.
- description: Cookie login step
required: [ type, cookies ]
properties:
diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go
index 9db5f442..82ea8c2b 100644
--- a/bridgev2/matrix/publicmedia.go
+++ b/bridgev2/matrix/publicmedia.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,18 +7,26 @@
package matrix
import (
+ "context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
+ "mime"
"net/http"
+ "net/url"
+ "slices"
+ "strings"
"time"
- "github.com/gorilla/mux"
+ "github.com/rs/zerolog"
"maunium.net/go/mautrix/bridgev2"
+ "maunium.net/go/mautrix/bridgev2/database"
+ "maunium.net/go/mautrix/crypto/attachment"
+ "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -35,7 +43,10 @@ func (br *Connector) initPublicMedia() error {
return fmt.Errorf("public media hash length is negative")
}
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
- br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
+ br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia)
+ br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia)
+ br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
+ br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia)
return nil
}
@@ -46,6 +57,20 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte {
return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)]
}
+func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte {
+ hasher := hmac.New(sha256.New, br.pubMediaSigKey)
+ hasher.Write([]byte(pm.MXC.String()))
+ hasher.Write([]byte(pm.MimeType))
+ if pm.Keys != nil {
+ hasher.Write([]byte(pm.Keys.Version))
+ hasher.Write([]byte(pm.Keys.Key.Algorithm))
+ hasher.Write([]byte(pm.Keys.Key.Key))
+ hasher.Write([]byte(pm.Keys.InitVector))
+ hasher.Write([]byte(pm.Keys.Hashes.SHA256))
+ }
+ return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength]
+}
+
func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte {
var expiresAt []byte
if br.Config.PublicMedia.Expiry > 0 {
@@ -76,16 +101,15 @@ var proxyHeadersToCopy = []string{
}
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
- vars := mux.Vars(r)
contentURI := id.ContentURI{
- Homeserver: vars["server"],
- FileID: vars["mediaID"],
+ Homeserver: r.PathValue("server"),
+ FileID: r.PathValue("mediaID"),
}
if !contentURI.IsValid() {
http.Error(w, "invalid content URI", http.StatusBadRequest)
return
}
- checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
+ checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
return
@@ -96,9 +120,47 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
http.Error(w, "checksum expired", http.StatusGone)
return
}
+ br.doProxyMedia(w, r, contentURI, nil, "")
+}
+
+func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) {
+ if !br.Config.PublicMedia.UseDatabase {
+ http.Error(w, "public media short links are disabled", http.StatusNotFound)
+ return
+ }
+ log := zerolog.Ctx(r.Context())
+ media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID"))
+ if err != nil {
+ log.Err(err).Msg("Failed to get public media from database")
+ http.Error(w, "failed to get media metadata", http.StatusInternalServerError)
+ return
+ } else if media == nil {
+ http.Error(w, "media ID not found", http.StatusNotFound)
+ return
+ } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) {
+ // This is not gone as it can still be refreshed in the DB
+ http.Error(w, "media expired", http.StatusNotFound)
+ return
+ } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil {
+ http.Error(w, "media keys are malformed", http.StatusInternalServerError)
+ return
+ }
+ br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType)
+}
+
+var safeMimes = []string{
+ "text/css", "text/plain", "text/csv",
+ "application/json", "application/ld+json",
+ "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif",
+ "video/mp4", "video/webm", "video/ogg", "video/quicktime",
+ "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave",
+ "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac",
+}
+
+func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) {
resp, err := br.Bot.Download(r.Context(), contentURI)
if err != nil {
- br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
+ zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
http.Error(w, "failed to download media", http.StatusInternalServerError)
return
}
@@ -106,11 +168,41 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
for _, hdr := range proxyHeadersToCopy {
w.Header()[hdr] = resp.Header[hdr]
}
+ stream := resp.Body
+ if encInfo != nil {
+ if mimeType == "" {
+ mimeType = "application/octet-stream"
+ }
+ contentDisposition := "attachment"
+ if slices.Contains(safeMimes, mimeType) {
+ contentDisposition = "inline"
+ }
+ dispositionArgs := map[string]string{}
+ if filename := r.PathValue("filename"); filename != "" {
+ dispositionArgs["filename"] = filename
+ }
+ w.Header().Set("Content-Type", mimeType)
+ w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs))
+ // Note: this won't check the Close result like it should, but it's probably not a big deal here
+ stream = encInfo.DecryptStream(stream)
+ } else if filename := r.PathValue("filename"); filename != "" {
+ contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
+ if contentDisposition == "" {
+ contentDisposition = "attachment"
+ }
+ w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{
+ "filename": filename,
+ }))
+ }
w.WriteHeader(http.StatusOK)
- _, _ = io.Copy(w, resp.Body)
+ _, _ = io.Copy(w, stream)
}
func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string {
+ return br.getPublicMediaAddressWithFileName(contentURI, "")
+}
+
+func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string {
if br.pubMediaSigKey == nil {
return ""
}
@@ -118,11 +210,69 @@ func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) strin
if err != nil || !parsed.IsValid() {
return ""
}
- return fmt.Sprintf(
- "%s/_mautrix/publicmedia/%s/%s/%s",
+ fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_"))
+ if fileName == ".." {
+ fileName = ""
+ }
+ parts := []string{
br.GetPublicAddress(),
+ strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
parsed.Homeserver,
parsed.FileID,
base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)),
- )
+ fileName,
+ }
+ if fileName == "" {
+ parts = parts[:len(parts)-1]
+ }
+ return strings.Join(parts, "/")
+}
+
+func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) {
+ if br.pubMediaSigKey == nil {
+ return "", bridgev2.ErrPublicMediaDisabled
+ }
+ if !br.Config.PublicMedia.UseDatabase {
+ if evt.File != nil {
+ return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled)
+ }
+ return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil
+ }
+ mxc := evt.URL
+ var keys *attachment.EncryptedFile
+ if evt.File != nil {
+ mxc = evt.File.URL
+ keys = &evt.File.EncryptedFile
+ }
+ parsedMXC, err := mxc.Parse()
+ if err != nil {
+ return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
+ }
+ pm := &database.PublicMedia{
+ MXC: parsedMXC,
+ Keys: keys,
+ MimeType: evt.GetInfo().MimeType,
+ }
+ if br.Config.PublicMedia.Expiry > 0 {
+ pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second)
+ }
+ pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm))
+ err = br.Bridge.DB.PublicMedia.Put(ctx, pm)
+ if err != nil {
+ return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
+ }
+ fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_"))
+ if fileName == ".." {
+ fileName = ""
+ }
+ parts := []string{
+ br.GetPublicAddress(),
+ strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
+ pm.PublicID,
+ fileName,
+ }
+ if fileName == "" {
+ parts = parts[:len(parts)-1]
+ }
+ return strings.Join(parts, "/"), nil
}
diff --git a/bridgev2/matrix/websocket.go b/bridgev2/matrix/websocket.go
index c679f960..b498cacd 100644
--- a/bridgev2/matrix/websocket.go
+++ b/bridgev2/matrix/websocket.go
@@ -57,7 +57,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) {
addr = br.Config.Homeserver.Address
}
for {
- err := br.AS.StartWebsocket(addr, onConnect)
+ err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect)
if errors.Is(err, appservice.ErrWebsocketManualStop) {
return
} else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced {
diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go
index ae1b99d7..be26db49 100644
--- a/bridgev2/matrixinterface.go
+++ b/bridgev2/matrixinterface.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -10,10 +10,11 @@ import (
"context"
"fmt"
"io"
+ "net/http"
"os"
"time"
- "github.com/gorilla/mux"
+ "go.mau.fi/util/exhttp"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
@@ -24,8 +25,10 @@ import (
)
type MatrixCapabilities struct {
- AutoJoinInvites bool
- BatchSending bool
+ AutoJoinInvites bool
+ BatchSending bool
+ ArbitraryMemberChange bool
+ ExtraProfileMeta bool
}
type MatrixConnector interface {
@@ -58,32 +61,55 @@ type MatrixConnector interface {
ServerName() string
}
+type MatrixConnectorWithArbitraryRoomState interface {
+ MatrixConnector
+ GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error)
+}
+
type MatrixConnectorWithServer interface {
+ MatrixConnector
GetPublicAddress() string
- GetRouter() *mux.Router
+ GetRouter() *http.ServeMux
+}
+
+type IProvisioningAPI interface {
+ GetRouter() *http.ServeMux
+ GetUser(r *http.Request) *User
+}
+
+type MatrixConnectorWithProvisioning interface {
+ MatrixConnector
+ GetProvisioning() IProvisioningAPI
}
type MatrixConnectorWithPublicMedia interface {
+ MatrixConnector
GetPublicMediaAddress(contentURI id.ContentURIString) string
+ GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error)
}
type MatrixConnectorWithNameDisambiguation interface {
+ MatrixConnector
IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error)
}
type MatrixConnectorWithBridgeIdentifier interface {
+ MatrixConnector
GetUniqueBridgeID() string
}
type MatrixConnectorWithURLPreviews interface {
+ MatrixConnector
GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error)
}
type MatrixConnectorWithPostRoomBridgeHandling interface {
+ MatrixConnector
HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error
}
type MatrixConnectorWithAnalytics interface {
+ MatrixConnector
TrackAnalytics(userID id.UserID, event string, properties map[string]any)
}
@@ -98,9 +124,15 @@ type DirectNotificationData struct {
}
type MatrixConnectorWithNotifications interface {
+ MatrixConnector
DisplayNotification(ctx context.Context, data *DirectNotificationData)
}
+type MatrixConnectorWithHTTPSettings interface {
+ MatrixConnector
+ GetHTTPClientSettings() exhttp.ClientSettings
+}
+
type MatrixSendExtra struct {
Timestamp time.Time
MessageMeta *database.Message
@@ -144,6 +176,10 @@ func (ce CallbackError) Unwrap() error {
return ce.Wrapped
}
+type EnsureJoinedParams struct {
+ Via []string
+}
+
type MatrixAPI interface {
GetMXID() id.UserID
IsDoublePuppet() bool
@@ -164,17 +200,26 @@ type MatrixAPI interface {
CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error)
DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error
- EnsureJoined(ctx context.Context, roomID id.RoomID) error
+ EnsureJoined(ctx context.Context, roomID id.RoomID, params ...EnsureJoinedParams) error
EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error
TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error
MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error
+
+ GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error)
}
type StreamOrderReadingMatrixAPI interface {
+ MatrixAPI
MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error
}
type MarkAsDMMatrixAPI interface {
+ MatrixAPI
MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error
}
+
+type EphemeralSendingMatrixAPI interface {
+ MatrixAPI
+ BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error)
+}
diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go
index bfbabd26..75c00cb0 100644
--- a/bridgev2/matrixinvite.go
+++ b/bridgev2/matrixinvite.go
@@ -88,6 +88,36 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI,
rejectInvite(ctx, evt, intent, "")
}
+func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) {
+ if portal.MXID == "" {
+ return
+ }
+ log := zerolog.Ctx(ctx)
+ existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID)
+ if err != nil {
+ log.Err(err).
+ Stringer("old_portal_mxid", portal.MXID).
+ Msg("Failed to check existing portal members, deleting room")
+ } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok {
+ log.Debug().
+ Stringer("old_portal_mxid", portal.MXID).
+ Msg("Inviter has no member event in old portal, deleting room")
+ } else if targetUserMember.Membership.IsInviteOrJoin() {
+ return
+ } else {
+ log.Debug().
+ Stringer("old_portal_mxid", portal.MXID).
+ Str("membership", string(targetUserMember.Membership)).
+ Msg("Inviter is not in old portal, deleting room")
+ }
+
+ if err = portal.RemoveMXID(ctx); err != nil {
+ log.Err(err).Msg("Failed to delete old portal mxid")
+ } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
+ log.Err(err).Msg("Failed to clean up old portal room")
+ }
+}
+
func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult {
ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey()))
validator, ok := br.Network.(IdentifierValidatingNetwork)
@@ -165,34 +195,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
return EventHandlingResultFailed
}
}
- if portal.MXID != "" {
- doCleanup := true
- existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID)
- if err != nil {
- log.Err(err).
- Stringer("old_portal_mxid", portal.MXID).
- Msg("Failed to check existing portal members, deleting room")
- } else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok {
- log.Debug().
- Stringer("old_portal_mxid", portal.MXID).
- Msg("Inviter has no member event in old portal, deleting room")
- } else if targetUserMember.Membership.IsInviteOrJoin() {
- doCleanup = false
- } else {
- log.Debug().
- Stringer("old_portal_mxid", portal.MXID).
- Str("membership", string(targetUserMember.Membership)).
- Msg("Inviter is not in old portal, deleting room")
- }
-
- if doCleanup {
- if err = portal.RemoveMXID(ctx); err != nil {
- log.Err(err).Msg("Failed to delete old portal mxid")
- } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
- log.Err(err).Msg("Failed to clean up old portal room")
- }
- }
- }
+ portal.CleanupOrphanedDM(ctx, sender.MXID)
err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID())
if err != nil {
log.Err(err).Msg("Failed to ensure bot is invited to room")
@@ -206,72 +209,67 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
return EventHandlingResultFailed
}
- didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID)
- if didSetPortal {
- message := "Private chat portal created"
- err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
- hasWarning := false
- if err != nil {
- log.Warn().Err(err).Msg("Failed to give power to bot in new DM")
- message += "\n\nWarning: failed to promote bot"
- hasWarning = true
- }
- if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
- log.Debug().
- Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
- Msg("Created DM was redirected to another user ID")
- _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
- Parsed: &event.MemberEventContent{
- Membership: event.MembershipLeave,
- Reason: "Direct chat redirected to another internal user ID",
- },
- }, time.Time{})
- if err != nil {
- log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
- }
- otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo)
- if err != nil {
- log.Err(err).Msg("Failed to get ghost of real portal other user ID")
- } else {
- invitedGhost = otherUserGhost
- }
- }
- if resp.PortalInfo != nil {
- portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{})
- } else {
- portal.UpdateCapabilities(ctx, sourceLogin, true)
- portal.UpdateBridgeInfo(ctx)
- }
- // TODO this might become unnecessary if UpdateInfo starts taking care of it
- _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
- Parsed: &event.ElementFunctionalMembersContent{
- ServiceMembers: []id.UserID{br.Bot.GetMXID()},
+ portal.roomCreateLock.Lock()
+ defer portal.roomCreateLock.Unlock()
+ portalMXID := portal.MXID
+ if portalMXID != "" {
+ sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL())
+ rejectInvite(ctx, evt, br.Bot, "")
+ return EventHandlingResultSuccess
+ }
+ err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
+ if err != nil {
+ log.Err(err).Msg("Failed to give permissions to bridge bot")
+ sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot")
+ rejectInvite(ctx, evt, br.Bot, "")
+ return EventHandlingResultSuccess
+ }
+ overrideIntent := invitedGhost.Intent
+ if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
+ log.Debug().
+ Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
+ Msg("Created DM was redirected to another user ID")
+ _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
+ Parsed: &event.MemberEventContent{
+ Membership: event.MembershipLeave,
+ Reason: "Direct chat redirected to another internal user ID",
},
}, time.Time{})
if err != nil {
- log.Warn().Err(err).Msg("Failed to set service members in room")
- if !hasWarning {
- message += "\n\nWarning: failed to set service members"
- hasWarning = true
- }
+ log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
}
- mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
- if ok {
- err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
- if err != nil {
- if hasWarning {
- message += fmt.Sprintf(", %s", err.Error())
- } else {
- message += fmt.Sprintf("\n\nWarning: %s", err.Error())
- }
- }
+ if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot {
+ overrideIntent = br.Bot
+ } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil {
+ log.Err(err).Msg("Failed to get ghost of real portal other user ID")
+ } else {
+ invitedGhost = otherUserGhost
+ overrideIntent = otherUserGhost.Intent
}
- sendNotice(ctx, evt, invitedGhost.Intent, message)
- } else {
- // TODO ensure user is invited even if PortalInfo wasn't provided?
- sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL())
- rejectInvite(ctx, evt, br.Bot, "")
}
+ err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{
+ // We locked it before checking the mxid
+ RoomCreateAlreadyLocked: true,
+
+ FailIfMXIDSet: true,
+ ChatInfo: resp.PortalInfo,
+ ChatInfoSource: sourceLogin,
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to update Matrix room ID for new DM portal")
+ sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work")
+ return EventHandlingResultSuccess
+ }
+ message := "Private chat portal created"
+ mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
+ if ok {
+ err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
+ if err != nil {
+ log.Err(err).Msg("Error in connector newly bridged room handler")
+ message += fmt.Sprintf("\n\nWarning: %s", err.Error())
+ }
+ }
+ sendNotice(ctx, evt, overrideIntent, message)
return EventHandlingResultSuccess
}
@@ -294,21 +292,3 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith
}
return nil
}
-
-func (portal *Portal) setMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
- portal.roomCreateLock.Lock()
- defer portal.roomCreateLock.Unlock()
- if portal.MXID != "" {
- return false
- }
- portal.MXID = roomID
- portal.updateLogger()
- portal.Bridge.cacheLock.Lock()
- portal.Bridge.portalsByMXID[portal.MXID] = portal
- portal.Bridge.cacheLock.Unlock()
- err := portal.Save(ctx)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating mxid")
- }
- return true
-}
diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go
index 7118649d..df0c9e4d 100644
--- a/bridgev2/messagestatus.go
+++ b/bridgev2/messagestatus.go
@@ -20,6 +20,7 @@ import (
type MessageStatusEventInfo struct {
RoomID id.RoomID
+ TransactionID string
SourceEventID id.EventID
NewEventID id.EventID
EventType event.Type
@@ -41,6 +42,7 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo {
return &MessageStatusEventInfo{
RoomID: evt.RoomID,
+ TransactionID: evt.Unsigned.TransactionID,
SourceEventID: evt.ID,
EventType: evt.Type,
MessageType: evt.Content.AsMessage().MsgType,
@@ -182,9 +184,10 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe
Type: event.RelReference,
EventID: evt.SourceEventID,
},
- Status: ms.Status,
- Reason: ms.ErrorReason,
- Message: ms.Message,
+ TargetTxnID: evt.TransactionID,
+ Status: ms.Status,
+ Reason: ms.ErrorReason,
+ Message: ms.Message,
}
if ms.InternalError != nil {
content.InternalError = ms.InternalError.Error()
diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go
index 443d3655..e3a6df70 100644
--- a/bridgev2/networkid/bridgeid.go
+++ b/bridgev2/networkid/bridgeid.go
@@ -47,8 +47,8 @@ type PortalID string
// As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true.
// The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user.
type PortalKey struct {
- ID PortalID
- Receiver UserLoginID
+ ID PortalID `json:"portal_id"`
+ Receiver UserLoginID `json:"portal_receiver,omitempty"`
}
func (pk PortalKey) IsEmpty() bool {
diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go
index eb38bd2d..efc5f100 100644
--- a/bridgev2/networkinterface.go
+++ b/bridgev2/networkinterface.go
@@ -16,7 +16,9 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/configupgrade"
"go.mau.fi/util/ptr"
+ "go.mau.fi/util/random"
+ "maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
@@ -117,11 +119,15 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa
mediaPart.Content.EnsureHasHTML()
mediaPart.Content.Body += "\n\n" + textPart.Content.Body
mediaPart.Content.FormattedBody += "
" + textPart.Content.FormattedBody
+ mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions)
+ mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...)
} else {
mediaPart.Content.FileName = mediaPart.Content.Body
mediaPart.Content.Body = textPart.Content.Body
mediaPart.Content.Format = textPart.Content.Format
mediaPart.Content.FormattedBody = textPart.Content.FormattedBody
+ mediaPart.Content.Mentions = textPart.Content.Mentions
+ mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews
}
if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok {
metaMerger.CopyFrom(textPart.DBMetadata)
@@ -255,6 +261,7 @@ type NetworkConnector interface {
}
type StoppableNetwork interface {
+ NetworkConnector
// Stop is called when the bridge is stopping, after all network clients have been disconnected.
Stop()
}
@@ -311,6 +318,16 @@ type MaxFileSizeingNetwork interface {
SetMaxFileSize(maxSize int64)
}
+type NetworkResettingNetwork interface {
+ NetworkConnector
+ // ResetHTTPTransport should recreate the HTTP client used by the bridge.
+ // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable.
+ ResetHTTPTransport()
+ // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections.
+ // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here.
+ ResetNetworkConnections()
+}
+
type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
type MatrixMessageResponse struct {
@@ -342,10 +359,16 @@ type NetworkGeneralCapabilities struct {
// Should the bridge re-request user info on incoming messages even if the ghost already has info?
// By default, info is only requested for ghosts with no name, and other updating is left to events.
AggressiveUpdateInfo bool
+ // Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message?
+ // This should be enabled if the network requires each message to be marked as read independently,
+ // and doesn't automatically do it when sending a message.
+ ImplicitReadReceipts bool
// If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave])
// to handle asynchronous message responses, this field can be set to enable
// automatic timeout errors in case the asynchronous response never arrives.
OutgoingMessageTimeouts *OutgoingTimeoutConfig
+ // Capabilities related to the provisioning API.
+ Provisioning ProvisioningCapabilities
}
// NetworkAPI is an interface representing a remote network client for a single user login.
@@ -679,6 +702,35 @@ type RoomTopicHandlingNetworkAPI interface {
HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error)
}
+type DisappearTimerChangingNetworkAPI interface {
+ NetworkAPI
+ // HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed.
+ // This method should update the Disappear field of the Portal with the new timer and return true
+ // if the change was successful. If the change is not successful, then the field should not be updated.
+ HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error)
+}
+
+// DeleteChatHandlingNetworkAPI is an optional interface that network connectors
+// can implement to delete a chat from the remote network.
+type DeleteChatHandlingNetworkAPI interface {
+ NetworkAPI
+ // HandleMatrixDeleteChat is called when the user explicitly deletes a chat.
+ HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error
+}
+
+// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors
+// can implement to accept message requests from the remote network.
+type MessageRequestAcceptingNetworkAPI interface {
+ NetworkAPI
+ // HandleMatrixAcceptMessageRequest is called when the user accepts a message request.
+ HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error
+}
+
+type BeeperAIStreamHandlingNetworkAPI interface {
+ NetworkAPI
+ HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error
+}
+
type ResolveIdentifierResponse struct {
// Ghost is the ghost of the user that the identifier resolves to.
// This field should be set whenever possible. However, it is not required,
@@ -698,6 +750,8 @@ type ResolveIdentifierResponse struct {
Chat *CreateChatResponse
}
+var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10))
+
type CreateChatResponse struct {
PortalKey networkid.PortalKey
// Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary.
@@ -706,6 +760,17 @@ type CreateChatResponse struct {
// If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user,
// this field should have the user ID of said different user.
DMRedirectedTo networkid.UserID
+
+ FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant
+}
+
+type CreateChatFailedParticipant struct {
+ Reason string `json:"reason"`
+ InviteEventType string `json:"invite_event_type,omitempty"`
+ InviteContent *event.Content `json:"invite_content,omitempty"`
+
+ UserMXID id.UserID `json:"user_mxid,omitempty"`
+ DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"`
}
// IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats.
@@ -740,7 +805,83 @@ type UserSearchingNetworkAPI interface {
type GroupCreatingNetworkAPI interface {
IdentifierResolvingNetworkAPI
- CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error)
+ CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error)
+}
+
+type PersonalFilteringCustomizingNetworkAPI interface {
+ NetworkAPI
+ CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom)
+}
+
+type ProvisioningCapabilities struct {
+ ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"`
+ GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"`
+}
+
+type ResolveIdentifierCapabilities struct {
+ // Can DMs be created after resolving an identifier?
+ CreateDM bool `json:"create_dm"`
+ // Can users be looked up by phone number?
+ LookupPhone bool `json:"lookup_phone"`
+ // Can users be looked up by email address?
+ LookupEmail bool `json:"lookup_email"`
+ // Can users be looked up by network-specific username?
+ LookupUsername bool `json:"lookup_username"`
+ // Can any phone number be contacted without having to validate it via lookup first?
+ AnyPhone bool `json:"any_phone"`
+ // Can a contact list be retrieved from the bridge?
+ ContactList bool `json:"contact_list"`
+ // Can users be searched by name on the remote network?
+ Search bool `json:"search"`
+}
+
+type GroupTypeCapabilities struct {
+ TypeDescription string `json:"type_description"`
+
+ Name GroupFieldCapability `json:"name"`
+ Username GroupFieldCapability `json:"username"`
+ Avatar GroupFieldCapability `json:"avatar"`
+ Topic GroupFieldCapability `json:"topic"`
+ Disappear GroupFieldCapability `json:"disappear"`
+ Participants GroupFieldCapability `json:"participants"`
+ Parent GroupFieldCapability `json:"parent"`
+}
+
+type GroupFieldCapability struct {
+ // Is setting this field allowed at all in the create request?
+ // Even if false, the network connector should attempt to set the metadata after group creation,
+ // as the allowed flag can't be enforced properly when creating a group for an existing Matrix room.
+ Allowed bool `json:"allowed"`
+ // Is setting this field mandatory for the creation to succeed?
+ Required bool `json:"required,omitempty"`
+ // The minimum/maximum length of the field, if applicable.
+ // For members, length means the number of members excluding the creator.
+ MinLength int `json:"min_length,omitempty"`
+ MaxLength int `json:"max_length,omitempty"`
+
+ // Only for the disappear field: allowed disappearing settings
+ DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"`
+
+ // This can be used to tell provisionutil not to call ValidateUserID on each participant.
+ // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs.
+ SkipIdentifierValidation bool `json:"-"`
+}
+
+type GroupCreateParams struct {
+ Type string `json:"type,omitempty"`
+
+ Username string `json:"username,omitempty"`
+ // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs
+ Participants []networkid.UserID `json:"participants,omitempty"`
+ Parent *networkid.PortalKey `json:"parent,omitempty"`
+
+ Name *event.RoomNameEventContent `json:"name,omitempty"`
+ Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"`
+ Topic *event.TopicEventContent `json:"topic,omitempty"`
+ Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"`
+
+ // An existing room ID to bridge to. If unset, a new room will be created.
+ RoomID id.RoomID `json:"room_id,omitempty"`
}
type MembershipChangeType struct {
@@ -780,16 +921,15 @@ type MatrixMembershipChange struct {
MatrixRoomMeta[*event.MemberEventContent]
Target GhostOrUserLogin
Type MembershipChangeType
+}
- // Deprecated: Use Target instead
- TargetGhost *Ghost
- // Deprecated: Use Target instead
- TargetUserLogin *UserLogin
+type MatrixMembershipResult struct {
+ RedirectTo networkid.UserID
}
type MembershipHandlingNetworkAPI interface {
NetworkAPI
- HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error)
+ HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error)
}
type SinglePowerLevelChange struct {
@@ -1028,6 +1168,11 @@ type RemoteChatDelete interface {
RemoteDeleteOnlyForMe
}
+type RemoteChatDeleteWithChildren interface {
+ RemoteChatDelete
+ DeleteChildren() bool
+}
+
type RemoteEventThatMayCreatePortal interface {
RemoteEvent
ShouldCreatePortal() bool
@@ -1260,12 +1405,14 @@ type MatrixMessageRemove struct {
type MatrixRoomMeta[ContentType any] struct {
MatrixEventBase[ContentType]
- PrevContent ContentType
+ PrevContent ContentType
+ IsStateRequest bool
}
type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent]
type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent]
type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent]
+type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer]
type MatrixReadReceipt struct {
Portal *Portal
@@ -1280,6 +1427,8 @@ type MatrixReadReceipt struct {
LastRead time.Time
// The receipt metadata.
Receipt event.ReadReceipt
+ // Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt.
+ Implicit bool
}
type MatrixTyping struct {
@@ -1293,6 +1442,9 @@ type MatrixViewingChat struct {
Portal *Portal
}
+type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent]
+type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent]
+type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent]
type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent]
type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent]
type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent]
diff --git a/bridgev2/portal.go b/bridgev2/portal.go
index ab1f37f1..16aa703b 100644
--- a/bridgev2/portal.go
+++ b/bridgev2/portal.go
@@ -19,6 +19,7 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/exfmt"
+ "go.mau.fi/util/exmaps"
"go.mau.fi/util/exslices"
"go.mau.fi/util/exsync"
"go.mau.fi/util/ptr"
@@ -85,9 +86,15 @@ type Portal struct {
lastCapUpdate time.Time
- roomCreateLock sync.Mutex
+ roomCreateLock sync.Mutex
+ cancelRoomCreate atomic.Pointer[context.CancelFunc]
+ RoomCreated *exsync.Event
- events chan portalEvent
+ functionalMembersLock sync.Mutex
+ functionalMembersCache *event.ElementFunctionalMembersContent
+
+ events chan portalEvent
+ deleted *exsync.Event
eventsLock sync.Mutex
eventIdx int
@@ -119,7 +126,15 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
currentlyTypingLogins: make(map[id.UserID]*UserLogin),
currentlyTypingGhosts: exsync.NewSet[id.UserID](),
outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage),
+
+ RoomCreated: exsync.NewEvent(),
+ deleted: exsync.NewEvent(),
}
+ if portal.MXID != "" {
+ portal.RoomCreated.Set()
+ }
+ // Putting the portal in the cache before it's fully initialized is mildly dangerous,
+ // but loading the relay user login may depend on it.
br.portalsByKey[portal.PortalKey] = portal
if portal.MXID != "" {
br.portalsByMXID[portal.MXID] = portal
@@ -128,12 +143,20 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
if portal.ParentKey.ID != "" {
portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false)
if err != nil {
+ delete(br.portalsByKey, portal.PortalKey)
+ if portal.MXID != "" {
+ delete(br.portalsByMXID, portal.MXID)
+ }
return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err)
}
}
if portal.RelayLoginID != "" {
portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID)
if err != nil {
+ delete(br.portalsByKey, portal.PortalKey)
+ if portal.MXID != "" {
+ delete(br.portalsByMXID, portal.MXID)
+ }
return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err)
}
}
@@ -146,7 +169,9 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
}
func (portal *Portal) updateLogger() {
- logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID))
+ logWith := portal.Bridge.Log.With().
+ Str("portal_id", string(portal.ID)).
+ Str("portal_receiver", string(portal.Receiver))
if portal.MXID != "" {
logWith = logWith.Stringer("portal_mxid", portal.MXID)
}
@@ -170,6 +195,16 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta
return output, nil
}
+func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) {
+ if dbPortal == nil {
+ return nil, nil
+ } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok {
+ return cached, nil
+ } else {
+ return br.loadPortal(ctx, dbPortal, nil, nil)
+ }
+}
+
func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) {
if br.Config.SplitPortals && key.Receiver == "" {
return nil, fmt.Errorf("receiver must always be set when split portals is enabled")
@@ -259,6 +294,26 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us
return br.loadManyPortals(ctx, rows)
}
+func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) {
+ br.cacheLock.Lock()
+ defer br.cacheLock.Unlock()
+ rows, err := br.DB.Portal.GetChildren(ctx, parent)
+ if err != nil {
+ return nil, err
+ }
+ return br.loadManyPortals(ctx, rows)
+}
+
+func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) {
+ br.cacheLock.Lock()
+ defer br.cacheLock.Unlock()
+ dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID)
+ if err != nil {
+ return nil, err
+ }
+ return br.loadPortalWithCacheCheck(ctx, dbPortal)
+}
+
func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
@@ -284,15 +339,23 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port
}
func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult {
+ if portal.deleted.IsSet() {
+ return EventHandlingResultIgnored
+ }
if PortalEventBuffer == 0 {
portal.eventsLock.Lock()
defer portal.eventsLock.Unlock()
portal.eventIdx++
- return portal.handleSingleEventAsync(portal.eventIdx, evt)
+ return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt)
} else {
+ if portal.events == nil {
+ panic(fmt.Errorf("queueEvent into uninitialized portal %s", portal.PortalKey))
+ }
select {
case portal.events <- evt:
return EventHandlingResultQueued
+ case <-portal.deleted.GetChan():
+ return EventHandlingResultIgnored
default:
zerolog.Ctx(ctx).Error().
Str("portal_id", string(portal.ID)).
@@ -317,64 +380,71 @@ func (portal *Portal) eventLoop() {
go portal.pendingMessageTimeoutLoop(ctx, cfg)
defer cancel()
}
- i := 0
- for rawEvt := range portal.events {
- i++
- portal.handleSingleEventAsync(i, rawEvt)
+ deleteCh := portal.deleted.GetChan()
+ for i := 0; ; i++ {
+ select {
+ case rawEvt := <-portal.events:
+ if rawEvt == nil {
+ return
+ }
+ if portal.Bridge.Config.AsyncEvents {
+ go portal.handleSingleEventWithDelayLogging(i, rawEvt)
+ } else {
+ portal.handleSingleEventWithDelayLogging(i, rawEvt)
+ }
+ case <-deleteCh:
+ return
+ }
}
}
-func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) {
+func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) {
ctx := portal.getEventCtxWithLog(rawEvt, idx)
- if _, isCreate := rawEvt.(*portalCreateEvent); isCreate {
- portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {
- outerRes = res
- })
- } else if portal.Bridge.Config.AsyncEvents {
- outerRes = EventHandlingResultQueued
- go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {})
- } else {
- log := zerolog.Ctx(ctx)
- doneCh := make(chan struct{})
- var backgrounded atomic.Bool
- start := time.Now()
- var handleDuration time.Duration
- // Note: this will not set the success flag if the handler times out
- outerRes = EventHandlingResult{Queued: true}
- go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {
- outerRes = res
- handleDuration = time.Since(start)
- close(doneCh)
- if backgrounded.Load() {
+ log := zerolog.Ctx(ctx)
+ doneCh := make(chan struct{})
+ var backgrounded atomic.Bool
+ start := time.Now()
+ var handleDuration time.Duration
+ // Note: this will not set the success flag if the handler times out
+ outerRes = EventHandlingResult{Queued: true}
+ go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {
+ outerRes = res
+ handleDuration = time.Since(start)
+ close(doneCh)
+ if backgrounded.Load() {
+ log.Debug().
+ Time("started_at", start).
+ Stringer("duration", handleDuration).
+ Msg("Event that took too long finally finished handling")
+ }
+ })
+ tick := time.NewTicker(30 * time.Second)
+ _, isCreate := rawEvt.(*portalCreateEvent)
+ defer tick.Stop()
+ for i := 0; i < 10; i++ {
+ select {
+ case <-doneCh:
+ if i > 0 {
log.Debug().
Time("started_at", start).
Stringer("duration", handleDuration).
- Msg("Event that took too long finally finished handling")
+ Msg("Event that took long finished handling")
}
- })
- tick := time.NewTicker(30 * time.Second)
- defer tick.Stop()
- for i := 0; i < 10; i++ {
- select {
- case <-doneCh:
- if i > 0 {
- log.Debug().
- Time("started_at", start).
- Stringer("duration", handleDuration).
- Msg("Event that took long finished handling")
- }
- return
- case <-tick.C:
- log.Warn().
- Time("started_at", start).
- Msg("Event handling is taking long")
+ return
+ case <-tick.C:
+ log.Warn().
+ Time("started_at", start).
+ Msg("Event handling is taking long")
+ if isCreate {
+ // Never background portal creation events
+ i = 1
}
}
- log.Warn().
- Time("started_at", start).
- Msg("Event handling is taking too long, continuing in background")
- backgrounded.Store(true)
}
+ log.Warn().
+ Time("started_at", start).
+ Msg("Event handling is taking too long, continuing in background")
+ backgrounded.Store(true)
return
}
@@ -416,6 +486,11 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context {
logWith = logWith.Int64("remote_stream_order", remoteStreamOrder)
}
}
+ if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok {
+ if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() {
+ logWith = logWith.Time("remote_timestamp", remoteTimestamp)
+ }
+ }
case *portalCreateEvent:
return evt.ctx
}
@@ -455,7 +530,14 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
}()
switch evt := rawEvt.(type) {
case *portalMatrixEvent:
- res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt)
+ isStateRequest := evt.evt.Type == event.BeeperSendState
+ if isStateRequest {
+ if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil {
+ portal.sendErrorStatus(ctx, evt.evt, err)
+ return
+ }
+ }
+ res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest)
if res.SendMSS {
if res.Error != nil {
portal.sendErrorStatus(ctx, evt.evt, res.Error)
@@ -463,6 +545,21 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
portal.sendSuccessStatus(ctx, evt.evt, 0, "")
}
}
+ if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil {
+ portal.revertRoomMeta(ctx, evt.evt)
+ }
+ if isStateRequest && res.Success && !res.SkipStateEcho {
+ portal.sendRoomMeta(
+ ctx,
+ evt.sender.DoublePuppet(ctx),
+ time.UnixMilli(evt.evt.Timestamp),
+ evt.evt.Type,
+ evt.evt.GetStateKey(),
+ evt.evt.Content.Parsed,
+ false,
+ evt.evt.Content.Raw,
+ )
+ }
case *portalRemoteEvent:
res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt)
case *portalCreateEvent:
@@ -474,18 +571,44 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
}
}
+func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
+ content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent)
+ if !ok {
+ return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)
+ }
+ evt.Content = content.Content
+ evt.StateKey = &content.StateKey
+ evt.Type = event.Type{Type: content.Type, Class: event.StateEventType}
+ _ = evt.Content.ParseRaw(evt.Type)
+ mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ return fmt.Errorf("matrix connector doesn't support fetching state")
+ }
+ prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey())
+ if err != nil && !errors.Is(err, mautrix.MNotFound) {
+ return fmt.Errorf("failed to get prev event: %w", err)
+ } else if prevEvt != nil {
+ evt.Unsigned.PrevContent = &prevEvt.Content
+ evt.Unsigned.PrevSender = prevEvt.Sender
+ }
+ return nil
+}
+
func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) {
if portal.Receiver != "" {
login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver)
if err != nil {
return nil, nil, err
}
- if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() {
+ if login == nil {
+ return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn)
+ } else if !login.Client.IsLoggedIn() {
+ return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn)
+ } else if login.UserMXID != user.MXID {
if allowRelay && portal.Relay != nil {
return nil, nil, nil
}
- // TODO different error for this case?
- return nil, nil, ErrNotLoggedIn
+ return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID)
}
up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey)
return login, up, err
@@ -568,7 +691,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID,
var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"}
-func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
+func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
log := zerolog.Ctx(ctx)
if evt.Mautrix.EventSource&event.SourceEphemeral != 0 {
switch evt.Type {
@@ -576,11 +699,17 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
return portal.handleMatrixReceipts(ctx, evt)
case event.EphemeralEventTyping:
return portal.handleMatrixTyping(ctx, evt)
+ case event.BeeperEphemeralEventAIStream:
+ return portal.handleMatrixAIStream(ctx, sender, evt)
default:
return EventHandlingResultIgnored
}
}
- login, _, err := portal.FindPreferredLogin(ctx, sender, true)
+ if evt.Type == event.StateTombstone {
+ // Tombstones aren't bridged so they don't need a login
+ return portal.handleMatrixTombstone(ctx, evt)
+ }
+ login, userPortal, err := portal.FindPreferredLogin(ctx, sender, true)
if err != nil {
log.Err(err).Msg("Failed to get user login to handle Matrix event")
if errors.Is(err, ErrNotLoggedIn) {
@@ -596,6 +725,9 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
}
var origSender *OrigSender
if login == nil {
+ if isStateRequest {
+ return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest)
+ }
login = portal.Relay
origSender = &OrigSender{
User: sender,
@@ -639,6 +771,21 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
}
// Copy logger because many of the handlers will use UpdateContext
ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx)
+
+ if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts && !evt.Type.IsAccountData() {
+ rrLog := log.With().Str("subaction", "implicit read receipt").Logger()
+ rrCtx := rrLog.WithContext(ctx)
+ rrLog.Debug().Msg("Sending implicit read receipt for event")
+ evtTS := time.UnixMilli(evt.Timestamp)
+ portal.callReadReceiptHandler(rrCtx, login, nil, &MatrixReadReceipt{
+ Portal: portal,
+ EventID: evt.ID,
+ Implicit: true,
+ ReadUpTo: evtTS,
+ Receipt: event.ReadReceipt{Timestamp: evtTS},
+ }, userPortal)
+ }
+
switch evt.Type {
case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse:
return portal.handleMatrixMessage(ctx, login, origSender, evt)
@@ -651,11 +798,13 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
case event.EventRedaction:
return portal.handleMatrixRedaction(ctx, login, origSender, evt)
case event.StateRoomName:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName)
case event.StateTopic:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic)
case event.StateRoomAvatar:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar)
+ case event.StateBeeperDisappearingTimer:
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer)
case event.StateEncryption:
// TODO?
return EventHandlingResultIgnored
@@ -666,9 +815,13 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
case event.AccountDataBeeperMute:
return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute)
case event.StateMember:
- return portal.handleMatrixMembership(ctx, login, origSender, evt)
+ return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest)
case event.StatePowerLevels:
- return portal.handleMatrixPowerLevels(ctx, login, origSender, evt)
+ return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest)
+ case event.BeeperDeleteChat:
+ return portal.handleMatrixDeleteChat(ctx, login, origSender, evt)
+ case event.BeeperAcceptMessageRequest:
+ return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt)
default:
return EventHandlingResultIgnored
}
@@ -688,7 +841,7 @@ func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event
sender, err := portal.Bridge.GetUserByMXID(ctx, userID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt)
}
@@ -726,15 +879,10 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e
EventID: eventID,
Receipt: receipt,
}
- if userPortal == nil {
- userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey)
- } else {
- evt.LastRead = userPortal.LastRead
- userPortal = userPortal.CopyWithoutValues()
- }
evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID)
if err != nil {
log.Err(err).Msg("Failed to get exact message from database")
+ evt.ReadUpTo = receipt.Timestamp
} else if evt.ExactMessage != nil {
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp)
@@ -743,21 +891,40 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e
} else {
evt.ReadUpTo = receipt.Timestamp
}
- err = rrClient.HandleMatrixReadReceipt(ctx, evt)
+ portal.callReadReceiptHandler(ctx, login, rrClient, evt, userPortal)
+}
+
+func (portal *Portal) callReadReceiptHandler(
+ ctx context.Context,
+ login *UserLogin,
+ rrClient ReadReceiptHandlingNetworkAPI,
+ evt *MatrixReadReceipt,
+ userPortal *database.UserPortal,
+) {
+ if rrClient == nil {
+ var ok bool
+ rrClient, ok = login.Client.(ReadReceiptHandlingNetworkAPI)
+ if !ok {
+ return
+ }
+ }
+ if userPortal == nil {
+ userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey)
+ } else {
+ evt.LastRead = userPortal.LastRead
+ userPortal = userPortal.CopyWithoutValues()
+ }
+ err := rrClient.HandleMatrixReadReceipt(ctx, evt)
if err != nil {
- log.Err(err).Msg("Failed to handle read receipt")
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to handle read receipt")
return
}
- if evt.ExactMessage != nil {
- userPortal.LastRead = evt.ExactMessage.Timestamp
- } else {
- userPortal.LastRead = receipt.Timestamp
- }
+ userPortal.LastRead = evt.ReadUpTo
err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal)
if err != nil {
- log.Err(err).Msg("Failed to save user portal metadata")
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata")
}
- portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID)
+ portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo)
}
func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult {
@@ -778,6 +945,50 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event)
return EventHandlingResultSuccess
}
+func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
+ log := zerolog.Ctx(ctx)
+ if sender == nil {
+ log.Error().Msg("Missing sender for Matrix AI stream event")
+ return EventHandlingResultIgnored
+ }
+ login, _, err := portal.FindPreferredLogin(ctx, sender, true)
+ if err != nil {
+ log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ var origSender *OrigSender
+ if login == nil {
+ if portal.Relay == nil {
+ return EventHandlingResultIgnored
+ }
+ login = portal.Relay
+ origSender = &OrigSender{
+ User: sender,
+ UserID: sender.MXID,
+ }
+ }
+ content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent)
+ if !ok {
+ log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
+ return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
+ }
+ api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI)
+ if !ok {
+ return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported)
+ }
+ err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{
+ Event: evt,
+ Content: content,
+ Portal: portal,
+ OrigSender: origSender,
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to handle Matrix AI stream event")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ return EventHandlingResultSuccess.WithMSS()
+}
+
func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) {
for _, userID := range userIDs {
login, ok := portal.currentlyTypingLogins[userID]
@@ -886,8 +1097,18 @@ func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content
feat.Caption.Reject() {
return ErrCaptionsNotAllowed
}
- if content.Info != nil && content.Info.MimeType != "" {
- if feat.GetMimeSupport(content.Info.MimeType).Reject() {
+ if content.Info != nil {
+ dur := time.Duration(content.Info.Duration) * time.Millisecond
+ if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration {
+ if capMsgType == event.CapMsgVoice {
+ return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration))
+ }
+ return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration))
+ }
+ if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize {
+ return fmt.Errorf("%w: %.1f MiB is larger than the maximum of %.1f MiB", ErrMediaTooLarge, float64(content.Info.Size)/1024/1024, float64(feat.MaxSize)/1024/1024)
+ }
+ if content.Info.MimeType != "" && feat.GetMimeSupport(content.Info.MimeType).Reject() {
return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType)
}
}
@@ -947,10 +1168,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
log.Debug().Msg("Ignoring poll event from relayed user")
return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser)
}
- msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender)
- if err != nil {
- log.Err(err).Msg("Failed to format message for relaying")
- return EventHandlingResultFailed.WithMSSError(err)
+ if !caps.PerMessageProfileRelay {
+ msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender)
+ if err != nil {
+ log.Err(err).Msg("Failed to format message for relaying")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
}
}
if msgContent != nil {
@@ -1018,6 +1241,16 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
}
+ var messageTimer *event.BeeperDisappearingTimer
+ if msgContent != nil {
+ messageTimer = msgContent.BeeperDisappearingTimer
+ }
+ if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer {
+ log.Warn().
+ Any("event_timer", messageTimer).
+ Any("portal_timer", portal.Disappear.ToEventContent()).
+ Msg("Mismatching disappearing timer in event")
+ }
wrappedMsgEvt := &MatrixMessage{
MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{
@@ -1043,6 +1276,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
+ err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps)
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to auto-accept message request on message")
+ // TODO stop processing?
+ }
+
var resp *MatrixMessageResponse
if msgContent != nil {
resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt)
@@ -1094,22 +1333,23 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID)
}
- if portal.Disappear.Type != database.DisappearingTypeNone {
+ ds := portal.Disappear
+ if messageTimer != nil {
+ ds = database.DisappearingSettingFromEvent(messageTimer)
+ }
+ if ds.Type != event.DisappearingTypeNone {
go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{
- RoomID: portal.MXID,
- EventID: message.MXID,
- DisappearingSetting: database.DisappearingSetting{
- Type: portal.Disappear.Type,
- Timer: portal.Disappear.Timer,
- DisappearAt: message.Timestamp.Add(portal.Disappear.Timer),
- },
+ RoomID: portal.MXID,
+ EventID: message.MXID,
+ Timestamp: message.Timestamp,
+ DisappearingSetting: ds.StartingAt(message.Timestamp),
})
}
if resp.Pending {
// Not exactly queued, but not finished either
return EventHandlingResultQueued
}
- return EventHandlingResultSuccess
+ return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder)
}
// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message.
@@ -1298,7 +1538,7 @@ func (portal *Portal) handleMatrixEdit(
return EventHandlingResultSuccess
}
-func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult {
+func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) {
log := zerolog.Ctx(ctx)
reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI)
if !ok {
@@ -1321,6 +1561,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
log.Warn().Msg("Reaction target message not found in database")
return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound))
}
+ caps := sender.Client.GetCapabilities(ctx, portal)
+ err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps)
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction")
+ // TODO stop processing?
+ }
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("reaction_target_remote_id", string(reactionTarget.ID))
})
@@ -1343,6 +1589,31 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if portal.Bridge.Config.OutgoingMessageReID {
deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID)
}
+ defer func() {
+ // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction
+ if handleRes.Success {
+ portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
+ }
+ }()
+ removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) {
+ if !handleRes.Success {
+ return
+ }
+ _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
+ Parsed: &event.RedactionEventContent{
+ Redacts: oldReact.MXID,
+ },
+ }, nil)
+ if err != nil {
+ log.Err(err).Msg("Failed to remove old reaction")
+ }
+ if deleteDB {
+ err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact)
+ if err != nil {
+ log.Err(err).Msg("Failed to delete old reaction from database")
+ }
+ }
+ }
existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
@@ -1351,17 +1622,10 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if existing.EmojiID != "" || existing.Emoji == preResp.Emoji {
log.Debug().Msg("Ignoring duplicate reaction")
portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
- return EventHandlingResultIgnored
+ return EventHandlingResultIgnored.WithEventID(deterministicID)
}
react.ReactionToOverride = existing
- _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
- Parsed: &event.RedactionEventContent{
- Redacts: existing.MXID,
- },
- }, nil)
- if err != nil {
- log.Err(err).Msg("Failed to remove old reaction")
- }
+ defer removeOutdatedReaction(existing, false)
}
react.PreHandleResp = &preResp
if preResp.MaxReactions > 0 {
@@ -1376,18 +1640,14 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
// Keep n-1 previous reactions and remove the rest
react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1]
for _, oldReaction := range allReactions[preResp.MaxReactions-1:] {
- _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
- Parsed: &event.RedactionEventContent{
- Redacts: oldReaction.MXID,
- },
- }, nil)
- if err != nil {
- log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded")
- }
- err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction)
- if err != nil {
- log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded")
+ if existing != nil && oldReaction.EmojiID == existing.EmojiID {
+ // Don't double-delete on networks that only allow one emoji
+ continue
}
+ // Intentionally defer in a loop, there won't be that many items,
+ // and we want all of them to be done after this function completes successfully
+ //goland:noinspection GoDeferInLoop
+ defer removeOutdatedReaction(oldReaction, true)
}
}
}
@@ -1432,8 +1692,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if err != nil {
log.Err(err).Msg("Failed to save reaction to database")
}
- portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
- return EventHandlingResultSuccess
+ return EventHandlingResultSuccess.WithEventID(deterministicID)
}
func handleMatrixRoomMeta[APIType any, ContentType any](
@@ -1442,11 +1701,19 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
+ isStateRequest bool,
fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error),
) EventHandlingResult {
+ if evt.StateKey == nil || *evt.StateKey != "" {
+ return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
+ }
+ //caps := sender.Client.GetCapabilities(ctx, portal)
+ //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported {
+ // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed))
+ //}
api, ok := sender.Client.(APIType)
if !ok {
- return EventHandlingResultIgnored.WithMSSError(ErrRoomMetadataNotSupported)
+ return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type))
}
log := zerolog.Ctx(ctx)
content, ok := evt.Content.Parsed.(ContentType)
@@ -1470,6 +1737,18 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
portal.sendSuccessStatus(ctx, evt, 0, "")
return EventHandlingResultIgnored
}
+ case *event.BeeperDisappearingTimer:
+ if typedContent.Type == event.DisappearingTypeNone || typedContent.Timer.Duration <= 0 {
+ typedContent.Type = event.DisappearingTypeNone
+ typedContent.Timer.Duration = 0
+ }
+ if typedContent.Type == portal.Disappear.Type && typedContent.Timer.Duration == portal.Disappear.Timer {
+ portal.sendSuccessStatus(ctx, evt, 0, "")
+ return EventHandlingResultIgnored
+ }
+ if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) {
+ return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported)
+ }
}
var prevContent ContentType
if evt.Unsigned.PrevContent != nil {
@@ -1486,14 +1765,17 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- PrevContent: prevContent,
+ IsStateRequest: isStateRequest,
+ PrevContent: prevContent,
})
if err != nil {
log.Err(err).Msg("Failed to handle Matrix room metadata")
return EventHandlingResultFailed.WithMSSError(err)
}
if changed {
- portal.UpdateBridgeInfo(ctx)
+ if evt.Type != event.StateBeeperDisappearingTimer {
+ portal.UpdateBridgeInfo(ctx)
+ }
err = portal.Save(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal after updating room metadata")
@@ -1554,12 +1836,139 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos
}
}
-func (portal *Portal) handleMatrixMembership(
+func (portal *Portal) handleMatrixAcceptMessageRequest(
ctx context.Context,
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
) EventHandlingResult {
+ if origSender != nil {
+ return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser)
+ }
+ log := zerolog.Ctx(ctx)
+ content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent)
+ if !ok {
+ log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
+ return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
+ }
+ api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI)
+ if !ok {
+ return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported)
+ }
+ err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{
+ Event: evt,
+ Content: content,
+ Portal: portal,
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to handle Matrix accept message request")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ if portal.MessageRequest {
+ portal.MessageRequest = false
+ portal.UpdateBridgeInfo(ctx)
+ err = portal.Save(ctx)
+ if err != nil {
+ log.Err(err).Msg("Failed to save portal after accepting message request")
+ }
+ }
+ return EventHandlingResultSuccess.WithMSS()
+}
+
+func (portal *Portal) autoAcceptMessageRequest(
+ ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures,
+) error {
+ if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported {
+ return nil
+ }
+ mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI)
+ if !ok {
+ return nil
+ }
+ err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{
+ Event: evt,
+ Content: &event.BeeperAcceptMessageRequestEventContent{
+ IsImplicit: true,
+ },
+ Portal: portal,
+ OrigSender: origSender,
+ })
+ if err != nil {
+ return err
+ }
+ if portal.MessageRequest {
+ portal.MessageRequest = false
+ portal.UpdateBridgeInfo(ctx)
+ err = portal.Save(ctx)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request")
+ }
+ }
+ return nil
+}
+
+func (portal *Portal) handleMatrixDeleteChat(
+ ctx context.Context,
+ sender *UserLogin,
+ origSender *OrigSender,
+ evt *event.Event,
+) EventHandlingResult {
+ if origSender != nil {
+ return EventHandlingResultFailed.WithMSSError(ErrIgnoringDeleteChatRelayedUser)
+ }
+ log := zerolog.Ctx(ctx)
+ content, ok := evt.Content.Parsed.(*event.BeeperChatDeleteEventContent)
+ if !ok {
+ log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
+ return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
+ }
+ api, ok := sender.Client.(DeleteChatHandlingNetworkAPI)
+ if !ok {
+ return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported)
+ }
+ err := api.HandleMatrixDeleteChat(ctx, &MatrixDeleteChat{
+ Event: evt,
+ Content: content,
+ Portal: portal,
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to handle Matrix chat delete")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ if portal.Receiver == "" {
+ _, others, err := portal.findOtherLogins(ctx, sender)
+ if err != nil {
+ log.Err(err).Msg("Failed to check if portal has other logins")
+ return EventHandlingResultFailed.WithError(err)
+ } else if len(others) > 0 {
+ log.Debug().Msg("Not deleting portal after chat delete as other logins are present")
+ return EventHandlingResultSuccess
+ }
+ }
+ err = portal.Delete(ctx)
+ if err != nil {
+ log.Err(err).Msg("Failed to delete portal from database")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false)
+ if err != nil {
+ log.Err(err).Msg("Failed to delete Matrix room")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ // No MSS here as the portal was deleted
+ return EventHandlingResultSuccess
+}
+
+func (portal *Portal) handleMatrixMembership(
+ ctx context.Context,
+ sender *UserLogin,
+ origSender *OrigSender,
+ evt *event.Event,
+ isStateRequest bool,
+) EventHandlingResult {
+ if evt.StateKey == nil {
+ return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
+ }
log := zerolog.Ctx(ctx)
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
if !ok {
@@ -1595,7 +2004,6 @@ func (portal *Portal) handleMatrixMembership(
return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent)
}
targetGhost, _ := target.(*Ghost)
- targetUserLogin, _ := target.(*UserLogin)
membershipChange := &MatrixMembershipChange{
MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{
MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{
@@ -1606,19 +2014,60 @@ func (portal *Portal) handleMatrixMembership(
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- PrevContent: prevContent,
+ IsStateRequest: isStateRequest,
+ PrevContent: prevContent,
},
- Target: target,
- TargetGhost: targetGhost,
- TargetUserLogin: targetUserLogin,
- Type: membershipChangeType,
+ Target: target,
+ Type: membershipChangeType,
}
- _, err = api.HandleMatrixMembership(ctx, membershipChange)
+ res, err := api.HandleMatrixMembership(ctx, membershipChange)
if err != nil {
log.Err(err).Msg("Failed to handle Matrix membership change")
return EventHandlingResultFailed.WithMSSError(err)
}
- return EventHandlingResultSuccess.WithMSS()
+ didRedirectInvite := membershipChangeType == Invite &&
+ targetGhost != nil &&
+ res != nil &&
+ res.RedirectTo != "" &&
+ res.RedirectTo != targetGhost.ID
+ if didRedirectInvite {
+ log.Debug().
+ Str("orig_id", string(targetGhost.ID)).
+ Str("redirect_id", string(res.RedirectTo)).
+ Msg("Invite was redirected to different ghost")
+ var redirectGhost *Ghost
+ redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo)
+ if err != nil {
+ log.Err(err).Msg("Failed to get redirect target ghost")
+ return EventHandlingResultFailed.WithError(err)
+ }
+ if !isStateRequest {
+ portal.sendRoomMeta(
+ ctx,
+ sender.User.DoublePuppet(ctx),
+ time.UnixMilli(evt.Timestamp),
+ event.StateMember,
+ evt.GetStateKey(),
+ &event.MemberEventContent{
+ Membership: event.MembershipLeave,
+ Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo),
+ },
+ true,
+ nil,
+ )
+ }
+ portal.sendRoomMeta(
+ ctx,
+ sender.User.DoublePuppet(ctx),
+ time.UnixMilli(evt.Timestamp),
+ event.StateMember,
+ redirectGhost.Intent.GetMXID().String(),
+ content,
+ false,
+ nil,
+ )
+ }
+ return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite)
}
func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange {
@@ -1643,13 +2092,27 @@ func (portal *Portal) handleMatrixPowerLevels(
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
+ isStateRequest bool,
) EventHandlingResult {
+ if evt.StateKey == nil || *evt.StateKey != "" {
+ return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
+ }
log := zerolog.Ctx(ctx)
content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent)
if !ok {
log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
}
+ if content.CreateEvent == nil {
+ ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
+ if ok {
+ var err error
+ content.CreateEvent, err = ars.GetStateEvent(ctx, portal.MXID, event.StateCreate, "")
+ if err != nil {
+ return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err))
+ }
+ }
+ }
api, ok := sender.Client.(PowerLevelHandlingNetworkAPI)
if !ok {
return EventHandlingResultIgnored.WithMSSError(ErrPowerLevelsNotSupported)
@@ -1658,6 +2121,7 @@ func (portal *Portal) handleMatrixPowerLevels(
if evt.Unsigned.PrevContent != nil {
_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.PowerLevelsEventContent)
+ prevContent.CreateEvent = content.CreateEvent
}
plChange := &MatrixPowerLevelChange{
@@ -1670,7 +2134,8 @@ func (portal *Portal) handleMatrixPowerLevels(
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- PrevContent: prevContent,
+ IsStateRequest: isStateRequest,
+ PrevContent: prevContent,
},
Users: make(map[id.UserID]*UserPowerLevelChange),
Events: make(map[string]*SinglePowerLevelChange),
@@ -1716,6 +2181,256 @@ func (portal *Portal) handleMatrixPowerLevels(
return EventHandlingResultSuccess.WithMSS()
}
+func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult {
+ if evt.StateKey == nil || *evt.StateKey != "" || portal.MXID != evt.RoomID {
+ return EventHandlingResultIgnored
+ }
+ log := *zerolog.Ctx(ctx)
+ sentByBridge := evt.Sender == portal.Bridge.Bot.GetMXID() || portal.Bridge.IsGhostMXID(evt.Sender)
+ var senderUser *User
+ var err error
+ if !sentByBridge {
+ senderUser, err = portal.Bridge.GetUserByMXID(ctx, evt.Sender)
+ if err != nil {
+ log.Err(err).Msg("Failed to get tombstone sender user")
+ return EventHandlingResultFailed.WithError(err)
+ }
+ }
+ content, ok := evt.Content.Parsed.(*event.TombstoneEventContent)
+ if !ok {
+ log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
+ return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
+ }
+ log = log.With().
+ Stringer("replacement_room", content.ReplacementRoom).
+ Logger()
+ if content.ReplacementRoom == "" {
+ log.Info().Msg("Received tombstone with no replacement room, cleaning up portal")
+ err := portal.RemoveMXID(ctx)
+ if err != nil {
+ log.Err(err).Msg("Failed to remove portal MXID")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true)
+ if err != nil {
+ log.Err(err).Msg("Failed to clean up Matrix room")
+ return EventHandlingResultFailed.WithError(err)
+ }
+ return EventHandlingResultSuccess
+ }
+ existingMemberEvt, err := portal.Bridge.Matrix.GetMemberInfo(ctx, content.ReplacementRoom, portal.Bridge.Bot.GetMXID())
+ if err != nil {
+ log.Err(err).Msg("Failed to get member info of bot in replacement room")
+ return EventHandlingResultFailed.WithError(err)
+ }
+ leaveOnError := func() {
+ if existingMemberEvt != nil && existingMemberEvt.Membership == event.MembershipJoin {
+ return
+ }
+ log.Debug().Msg("Leaving replacement room with bot after tombstone validation failed")
+ _, err = portal.Bridge.Bot.SendState(
+ ctx,
+ content.ReplacementRoom,
+ event.StateMember,
+ portal.Bridge.Bot.GetMXID().String(),
+ &event.Content{
+ Parsed: &event.MemberEventContent{
+ Membership: event.MembershipLeave,
+ Reason: fmt.Sprintf("Failed to validate tombstone sent by %s from %s", evt.Sender, evt.RoomID),
+ },
+ },
+ time.Time{},
+ )
+ if err != nil {
+ log.Err(err).Msg("Failed to leave replacement room after tombstone validation failed")
+ }
+ }
+ var via []string
+ if senderHS := evt.Sender.Homeserver(); senderHS != "" {
+ via = []string{senderHS}
+ }
+ err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom, EnsureJoinedParams{Via: via})
+ if err != nil {
+ log.Err(err).Msg("Failed to join replacement room from tombstone")
+ return EventHandlingResultFailed.WithError(err)
+ }
+ if !sentByBridge && !senderUser.Permissions.Admin {
+ powers, err := portal.Bridge.Matrix.GetPowerLevels(ctx, content.ReplacementRoom)
+ if err != nil {
+ log.Err(err).Msg("Failed to get power levels in replacement room")
+ leaveOnError()
+ return EventHandlingResultFailed.WithError(err)
+ }
+ if powers.GetUserLevel(evt.Sender) < powers.Invite() {
+ log.Warn().Msg("Tombstone sender doesn't have enough power to invite the bot to the replacement room")
+ leaveOnError()
+ return EventHandlingResultIgnored
+ }
+ }
+ err = portal.UpdateMatrixRoomID(ctx, content.ReplacementRoom, UpdateMatrixRoomIDParams{
+ DeleteOldRoom: true,
+ FetchInfoVia: senderUser,
+ })
+ if errors.Is(err, ErrTargetRoomIsPortal) {
+ return EventHandlingResultIgnored
+ } else if err != nil {
+ return EventHandlingResultFailed.WithError(err)
+ }
+ return EventHandlingResultSuccess
+}
+
+var ErrTargetRoomIsPortal = errors.New("target room is already a portal")
+var ErrRoomAlreadyExists = errors.New("this portal already has a room")
+
+type UpdateMatrixRoomIDParams struct {
+ SyncDBMetadata func()
+ FailIfMXIDSet bool
+ OverwriteOldPortal bool
+ TombstoneOldRoom bool
+ DeleteOldRoom bool
+
+ RoomCreateAlreadyLocked bool
+
+ FetchInfoVia *User
+ ChatInfo *ChatInfo
+ ChatInfoSource *UserLogin
+}
+
+func (portal *Portal) UpdateMatrixRoomID(
+ ctx context.Context,
+ newRoomID id.RoomID,
+ params UpdateMatrixRoomIDParams,
+) error {
+ if !params.RoomCreateAlreadyLocked {
+ portal.roomCreateLock.Lock()
+ defer portal.roomCreateLock.Unlock()
+ }
+ oldRoom := portal.MXID
+ if oldRoom == newRoomID {
+ return nil
+ } else if oldRoom != "" && params.FailIfMXIDSet {
+ return ErrRoomAlreadyExists
+ }
+ log := zerolog.Ctx(ctx)
+ portal.Bridge.cacheLock.Lock()
+ // Wrap unlock in a sync.OnceFunc because we want to both defer it to catch early returns
+ // and unlock it before return if nothing goes wrong.
+ unlockCacheLock := sync.OnceFunc(portal.Bridge.cacheLock.Unlock)
+ defer unlockCacheLock()
+ if existingPortal, alreadyExists := portal.Bridge.portalsByMXID[newRoomID]; alreadyExists && !params.OverwriteOldPortal {
+ log.Warn().Msg("Replacement room is already a portal, ignoring")
+ return ErrTargetRoomIsPortal
+ } else if alreadyExists {
+ log.Debug().Msg("Replacement room is already a portal, overwriting")
+ existingPortal.MXID = ""
+ existingPortal.RoomCreated.Clear()
+ err := existingPortal.Save(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to clear mxid of existing portal: %w", err)
+ }
+ delete(portal.Bridge.portalsByMXID, portal.MXID)
+ }
+ portal.MXID = newRoomID
+ portal.RoomCreated.Set()
+ portal.Bridge.portalsByMXID[portal.MXID] = portal
+ portal.NameSet = false
+ portal.AvatarSet = false
+ portal.TopicSet = false
+ portal.InSpace = false
+ portal.CapState = database.CapabilityState{}
+ portal.lastCapUpdate = time.Time{}
+ if params.SyncDBMetadata != nil {
+ params.SyncDBMetadata()
+ }
+ unlockCacheLock()
+ portal.updateLogger()
+
+ err := portal.Save(ctx)
+ if err != nil {
+ log.Err(err).Msg("Failed to save portal in UpdateMatrixRoomID")
+ return err
+ }
+ log.Info().Msg("Successfully followed tombstone and updated portal MXID")
+ err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey)
+ if err != nil {
+ log.Err(err).Msg("Failed to update in_space flag for user portals after updating portal MXID")
+ }
+ go portal.addToUserSpaces(ctx)
+ if params.FetchInfoVia != nil {
+ go portal.updateInfoAfterTombstone(ctx, params.FetchInfoVia)
+ } else if params.ChatInfo != nil {
+ go portal.UpdateInfo(ctx, params.ChatInfo, params.ChatInfoSource, nil, time.Time{})
+ } else if params.ChatInfoSource != nil {
+ portal.UpdateCapabilities(ctx, params.ChatInfoSource, true)
+ portal.UpdateBridgeInfo(ctx)
+ }
+ go func() {
+ // TODO this might become unnecessary if UpdateInfo starts taking care of it
+ _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
+ Parsed: &event.ElementFunctionalMembersContent{
+ ServiceMembers: []id.UserID{portal.Bridge.Bot.GetMXID()},
+ },
+ }, time.Time{})
+ if err != nil {
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to set service members in new room")
+ }
+ }
+ }()
+ if params.TombstoneOldRoom && oldRoom != "" {
+ _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateTombstone, "", &event.Content{
+ Parsed: &event.TombstoneEventContent{
+ Body: "Room has been replaced.",
+ ReplacementRoom: newRoomID,
+ },
+ }, time.Now())
+ if err != nil {
+ log.Err(err).Msg("Failed to send tombstone event to old room")
+ }
+ }
+ if params.DeleteOldRoom && oldRoom != "" {
+ go func() {
+ err = portal.Bridge.Bot.DeleteRoom(ctx, oldRoom, true)
+ if err != nil {
+ log.Err(err).Msg("Failed to clean up old Matrix room after updating portal MXID")
+ }
+ }()
+ }
+ return nil
+}
+
+func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) {
+ log := zerolog.Ctx(ctx)
+ logins, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey)
+ if err != nil {
+ log.Err(err).Msg("Failed to get user logins in portal to sync info")
+ return
+ }
+ var preferredLogin *UserLogin
+ for _, login := range logins {
+ if !login.Client.IsLoggedIn() {
+ continue
+ } else if preferredLogin == nil {
+ preferredLogin = login
+ } else if senderUser != nil && login.User == senderUser {
+ preferredLogin = login
+ }
+ }
+ if preferredLogin == nil {
+ log.Warn().Msg("No logins found to sync info")
+ return
+ }
+ info, err := preferredLogin.Client.GetChatInfo(ctx, portal)
+ if err != nil {
+ log.Err(err).Msg("Failed to get chat info")
+ return
+ }
+ log.Info().
+ Str("info_source_login", string(preferredLogin.ID)).
+ Msg("Fetched info to update portal after tombstone")
+ portal.UpdateInfo(ctx, info, preferredLogin, nil, time.Time{})
+}
+
func (portal *Portal) handleMatrixRedaction(
ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event,
) EventHandlingResult {
@@ -1815,7 +2530,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin,
err = portal.createMatrixRoomInLoop(ctx, source, info, bundle)
if err != nil {
log.Err(err).Msg("Failed to create portal to handle event")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
if evtType == RemoteEventChatResync {
log.Debug().Msg("Not handling chat resync event further as portal was created by it")
@@ -1834,6 +2549,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin,
switch evtType {
case RemoteEventUnknown:
log.Debug().Msg("Ignoring remote event with type unknown")
+ res = EventHandlingResultIgnored
case RemoteEventMessage, RemoteEventMessageUpsert:
res = portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage))
case RemoteEventEdit:
@@ -1872,6 +2588,46 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin,
return
}
+func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) {
+ if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" {
+ return
+ }
+ ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ return
+ }
+ portal.functionalMembersLock.Lock()
+ defer portal.functionalMembersLock.Unlock()
+ var functionalMembers *event.ElementFunctionalMembersContent
+ if portal.functionalMembersCache != nil {
+ functionalMembers = portal.functionalMembersCache
+ } else {
+ evt, err := ars.GetStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "")
+ if err != nil && !errors.Is(err, mautrix.MNotFound) {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get functional members state event")
+ return
+ }
+ functionalMembers = &event.ElementFunctionalMembersContent{}
+ if evt != nil {
+ evtContent, ok := evt.Content.Parsed.(*event.ElementFunctionalMembersContent)
+ if ok && evtContent != nil {
+ functionalMembers = evtContent
+ }
+ }
+ }
+ // TODO what about non-double-puppeted user ghosts?
+ functionalMembers.Add(portal.Bridge.Bot.GetMXID())
+ if functionalMembers.Add(ghost.Intent.GetMXID()) {
+ _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
+ Parsed: functionalMembers,
+ }, time.Time{})
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to update functional members state event")
+ return
+ }
+ }
+}
+
func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) {
var ghost *Ghost
if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID {
@@ -1895,6 +2651,7 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS
return
}
ghost.UpdateInfoIfNecessary(ctx, source, evtType)
+ portal.ensureFunctionalMember(ctx, ghost)
}
if sender.IsFromMe {
intent = source.User.DoublePuppet(ctx)
@@ -2006,7 +2763,7 @@ func (portal *Portal) getRelationMeta(
log.Err(err).Msg("Failed to get last thread message from database")
}
if prevThreadEvent == nil {
- prevThreadEvent = threadRoot
+ prevThreadEvent = ptr.Clone(threadRoot)
}
}
return
@@ -2065,6 +2822,7 @@ func (portal *Portal) sendConvertedMessage(
allSuccess := true
for i, part := range converted.Parts {
portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent)
+ part.Content.BeeperDisappearingTimer = converted.Disappear.ToEventContent()
dbMessage := &database.Message{
ID: id,
PartID: part.ID,
@@ -2109,13 +2867,14 @@ func (portal *Portal) sendConvertedMessage(
logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database")
allSuccess = false
}
- if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() {
- if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() {
+ if converted.Disappear.Type != event.DisappearingTypeNone && !dbMessage.HasFakeMXID() {
+ if converted.Disappear.Type == event.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() {
converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer)
}
portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{
RoomID: portal.MXID,
EventID: dbMessage.MXID,
+ Timestamp: dbMessage.Timestamp,
DisappearingSetting: converted.Disappear,
})
}
@@ -2204,7 +2963,7 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin,
err = portal.Bridge.DB.Message.Update(ctx, part)
if err != nil {
log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database")
- handleRes = EventHandlingResultFailed
+ handleRes = EventHandlingResultFailed.WithError(err)
}
}
}
@@ -2268,7 +3027,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin
} else {
log.Err(err).Msg("Failed to convert remote message")
portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
}
_, res = portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil)
@@ -2313,7 +3072,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e
existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID)
if err != nil {
log.Err(err).Msg("Failed to get edit target message")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
}
if existing == nil {
@@ -2338,7 +3097,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e
} else if err != nil {
log.Err(err).Msg("Failed to convert remote edit")
portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
res := portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt))
if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) {
@@ -2487,7 +3246,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User
targetMessage, err := portal.getTargetMessagePart(ctx, evt)
if err != nil {
log.Err(err).Msg("Failed to get target message for reaction")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if targetMessage == nil {
// TODO use deterministic event ID as target if applicable?
log.Warn().Msg("Target message for reaction not found")
@@ -2501,7 +3260,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User
}
if err != nil {
log.Err(err).Msg("Failed to get existing reactions for reaction sync")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction)
for _, existingReaction := range existingReactions {
@@ -2623,7 +3382,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi
targetMessage, err := portal.getTargetMessagePart(ctx, evt)
if err != nil {
log.Err(err).Msg("Failed to get target message for reaction")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if targetMessage == nil {
// TODO use deterministic event ID as target if applicable?
log.Warn().Msg("Target message for reaction not found")
@@ -2633,7 +3392,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi
existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) {
log.Debug().Msg("Ignoring duplicate reaction")
return EventHandlingResultIgnored
@@ -2703,7 +3462,7 @@ func (portal *Portal) sendConvertedReaction(
})
if err != nil {
logContext(log.Err(err)).Msg("Failed to send reaction to Matrix")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
logContext(log.Debug()).
Stringer("event_id", resp.EventID).
@@ -2712,7 +3471,7 @@ func (portal *Portal) sendConvertedReaction(
err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction)
if err != nil {
logContext(log.Err(err)).Msg("Failed to save reaction to database")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
return EventHandlingResultSuccess
}
@@ -2738,7 +3497,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us
targetReaction, err := portal.getTargetReaction(ctx, evt)
if err != nil {
log.Err(err).Msg("Failed to get target reaction for removal")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if targetReaction == nil {
log.Warn().Msg("Target reaction not found")
return EventHandlingResultIgnored
@@ -2762,7 +3521,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us
}, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction})
if err != nil {
log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction)
if err != nil {
@@ -2776,7 +3535,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use
targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage())
if err != nil {
log.Err(err).Msg("Failed to get target message for removal")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if len(targetParts) == 0 {
log.Debug().Msg("Target message not found")
return EventHandlingResultIgnored
@@ -2784,7 +3543,14 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use
onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe)
onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe()
if onlyForMe && portal.Receiver == "" {
- // TODO check if there are other user logins before deleting
+ _, others, err := portal.findOtherLogins(ctx, source)
+ if err != nil {
+ log.Err(err).Msg("Failed to check if portal has other logins")
+ return EventHandlingResultFailed.WithError(err)
+ } else if len(others) > 0 {
+ log.Debug().Msg("Ignoring delete for me event in portal with multiple logins")
+ return EventHandlingResultIgnored
+ }
}
intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove)
@@ -2846,7 +3612,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL
if err != nil {
log.Err(err).Str("last_target_id", string(lastTargetID)).
Msg("Failed to get last target message for read receipt")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if lastTarget == nil {
log.Debug().Str("last_target_id", string(lastTargetID)).
Msg("Last target message not found")
@@ -2865,7 +3631,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL
if err != nil {
log.Err(err).Str("target_id", string(targetID)).
Msg("Failed to get target message for read receipt")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) {
lastTarget = target
}
@@ -2895,20 +3661,24 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL
return evt.Int64("target_stream_order", targetStreamOrder)
}
err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt))
+ if readUpTo.IsZero() {
+ readUpTo = getEventTS(evt)
+ }
} else {
addTargetLog = func(evt *zerolog.Event) *zerolog.Event {
return evt.Stringer("target_mxid", lastTarget.MXID)
}
err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt))
+ readUpTo = lastTarget.Timestamp
}
if err != nil {
addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else {
addTargetLog(log.Debug()).Msg("Bridged read receipt")
}
if sender.IsFromMe {
- portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID)
+ portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo)
}
return EventHandlingResultSuccess
}
@@ -2925,13 +3695,13 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo
err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread())
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
return EventHandlingResultSuccess
}
func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult {
- if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID {
+ if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") {
return EventHandlingResultIgnored
}
intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt)
@@ -2943,7 +3713,7 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U
targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target)
if err != nil {
log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else if len(targetParts) == 0 {
continue
} else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost {
@@ -2978,7 +3748,7 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin,
err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
if timeout == 0 {
portal.currentlyTypingGhosts.Remove(intent.GetMXID())
@@ -2992,7 +3762,7 @@ func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *Us
info, err := evt.GetChatInfoChange(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt))
return EventHandlingResultSuccess
@@ -3030,22 +3800,43 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo
return EventHandlingResultSuccess
}
+func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) {
+ others, err = portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
+ if err != nil {
+ return
+ }
+ others = slices.DeleteFunc(others, func(up *database.UserPortal) bool {
+ if up.LoginID == source.ID {
+ ownUP = up
+ return true
+ }
+ return false
+ })
+ return
+}
+
+type childDeleteProxy struct {
+ RemoteChatDeleteWithChildren
+ child networkid.PortalKey
+ done func()
+}
+
+func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context {
+ return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children")
+}
+func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child }
+func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false }
+func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {}
+func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() }
+
func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult {
log := zerolog.Ctx(ctx)
if portal.Receiver == "" && evt.DeleteOnlyForMe() {
- logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
+ ownUP, logins, err := portal.findOtherLogins(ctx, source)
if err != nil {
log.Err(err).Msg("Failed to check if portal has other logins")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
- var ownUP *database.UserPortal
- logins = slices.DeleteFunc(logins, func(up *database.UserPortal) bool {
- if up.LoginID == source.ID {
- ownUP = up
- return true
- }
- return false
- })
if len(logins) > 0 {
log.Debug().Msg("Not deleting portal with other logins in remote chat delete event")
if ownUP != nil {
@@ -3066,22 +3857,47 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo
)
if err != nil {
log.Err(err).Msg("Failed to send leave state event for user after remote chat delete")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else {
log.Debug().Msg("Sent leave state event for user after remote chat delete")
return EventHandlingResultSuccess
}
}
}
+ if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace {
+ children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey)
+ if err != nil {
+ log.Err(err).Msg("Failed to fetch children to delete")
+ return EventHandlingResultFailed.WithError(err)
+ }
+ log.Debug().
+ Int("portal_count", len(children)).
+ Msg("Deleting child portals before remote chat delete")
+ var wg sync.WaitGroup
+ wg.Add(len(children))
+ for _, child := range children {
+ child.queueEvent(ctx, &portalRemoteEvent{
+ evt: &childDeleteProxy{
+ RemoteChatDeleteWithChildren: childDeleter,
+ child: child.PortalKey,
+ done: wg.Done,
+ },
+ source: source,
+ evtType: RemoteEventChatDelete,
+ })
+ }
+ wg.Wait()
+ log.Debug().Msg("Finished deleting child portals")
+ }
err := portal.Delete(ctx)
if err != nil {
log.Err(err).Msg("Failed to delete portal from database")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
}
err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false)
if err != nil {
log.Err(err).Msg("Failed to delete Matrix room")
- return EventHandlingResultFailed
+ return EventHandlingResultFailed.WithError(err)
} else {
log.Info().Msg("Deleted room after remote chat delete event")
return EventHandlingResultSuccess
@@ -3128,12 +3944,43 @@ type PortalInfo = ChatInfo
type ChatMember struct {
EventSender
Membership event.Membership
- Nickname *string
+ // Per-room nickname for the user. Not yet used.
+ Nickname *string
+ // The power level to set for the user when syncing power levels.
PowerLevel *int
- UserInfo *UserInfo
-
+ // Optional user info to sync the ghost user while updating membership.
+ UserInfo *UserInfo
+ // The user who sent the membership change (user who invited/kicked/banned this user).
+ // Not yet used. Not applicable if Membership is join or knock.
+ MemberSender EventSender
+ // Extra fields to include in the member event.
MemberEventExtra map[string]any
- PrevMembership event.Membership
+ // The expected previous membership. If this doesn't match, the change is ignored.
+ PrevMembership event.Membership
+}
+
+type ChatMemberMap map[networkid.UserID]ChatMember
+
+// Set adds the given entry to this map, overwriting any existing entry with the same Sender field.
+func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap {
+ if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe {
+ return cmm
+ }
+ cmm[member.Sender] = member
+ return cmm
+}
+
+// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists.
+// It returns true if the entry was added, false otherwise.
+func (cmm ChatMemberMap) Add(member ChatMember) bool {
+ if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe {
+ return false
+ }
+ if _, exists := cmm[member.Sender]; exists {
+ return false
+ }
+ cmm[member.Sender] = member
+ return true
}
type ChatMemberList struct {
@@ -3143,6 +3990,10 @@ type ChatMemberList struct {
// Should the bridge call IsThisUser for every member in the list?
// This should be used when SenderLogin can't be filled accurately.
CheckAllLogins bool
+ // Should any changes have the `com.beeper.exclude_from_timeline` flag set by default?
+ // This is recommended for syncs with non-real-time changes.
+ // Real-time changes (e.g. a user joining) should not set this flag set.
+ ExcludeChangesFromTimeline bool
// The total number of members in the chat, regardless of how many of those members are included in MemberMap.
TotalMemberCount int
@@ -3153,7 +4004,7 @@ type ChatMemberList struct {
// Deprecated: Use MemberMap instead to avoid duplicate entries
Members []ChatMember
- MemberMap map[networkid.UserID]ChatMember
+ MemberMap ChatMemberMap
PowerLevels *PowerLevelOverrides
}
@@ -3255,9 +4106,11 @@ type ChatInfo struct {
Disappear *database.DisappearingSetting
ParentID *networkid.PortalID
- UserLocal *UserLocalPortalInfo
+ UserLocal *UserLocalPortalInfo
+ MessageRequest *bool
+ CanBackfill bool
- CanBackfill bool
+ ExcludeChangesFromTimeline bool
ExtraUpdates ExtraUpdater[*Portal]
}
@@ -3291,26 +4144,36 @@ type UserLocalPortalInfo struct {
Tag *event.RoomTag
}
-func (portal *Portal) updateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool {
+func (portal *Portal) updateName(
+ ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool,
+) bool {
if portal.Name == name && (portal.NameSet || portal.MXID == "") {
return false
}
portal.Name = name
- portal.NameSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name})
+ portal.NameSet = portal.sendRoomMeta(
+ ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil,
+ )
return true
}
-func (portal *Portal) updateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool {
+func (portal *Portal) updateTopic(
+ ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool,
+) bool {
if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") {
return false
}
portal.Topic = topic
- portal.TopicSet = portal.sendRoomMeta(ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic})
+ portal.TopicSet = portal.sendRoomMeta(
+ ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil,
+ )
return true
}
-func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool {
- if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") {
+func (portal *Portal) updateAvatar(
+ ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool,
+) bool {
+ if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") {
return false
}
portal.AvatarID = avatar.ID
@@ -3326,13 +4189,15 @@ func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender M
portal.AvatarSet = false
zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar")
return true
- } else if newHash == portal.AvatarHash && portal.AvatarSet {
+ } else if newHash == portal.AvatarHash && portal.AvatarMXC != "" && portal.AvatarSet {
return true
}
portal.AvatarMXC = newMXC
portal.AvatarHash = newHash
}
- portal.AvatarSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC})
+ portal.AvatarSet = portal.sendRoomMeta(
+ ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil,
+ )
return true
}
@@ -3363,10 +4228,11 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) {
Creator: portal.Bridge.Bot.GetMXID(),
Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(),
Channel: event.BridgeInfoSection{
- ID: string(portal.ID),
- DisplayName: portal.Name,
- AvatarURL: portal.AvatarMXC,
- Receiver: string(portal.Receiver),
+ ID: string(portal.ID),
+ DisplayName: portal.Name,
+ AvatarURL: portal.AvatarMXC,
+ Receiver: string(portal.Receiver),
+ MessageRequest: portal.MessageRequest,
// TODO external URL?
},
BeeperRoomTypeV2: string(portal.RoomType),
@@ -3374,6 +4240,10 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) {
if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM {
bridgeInfo.BeeperRoomType = "dm"
}
+ if bridgeInfo.Protocol.ID == "slackgo" {
+ bridgeInfo.TempSlackRemoteIDMigratedFlag = true
+ bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true
+ }
parent := portal.GetTopLevelParent()
if parent != nil {
bridgeInfo.Network = &event.BridgeInfoSection{
@@ -3395,8 +4265,8 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) {
return
}
stateKey, bridgeInfo := portal.getBridgeInfo()
- portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo)
- portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo)
+ portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil)
+ portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil)
}
func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool {
@@ -3418,13 +4288,22 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin,
Str("old_id", portal.CapState.ID).
Str("new_id", capID).
Msg("Sending new room capability event")
- success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps)
+ success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil)
if !success {
return false
}
portal.CapState = database.CapabilityState{
Source: source.ID,
ID: capID,
+ Flags: portal.CapState.Flags,
+ }
+ if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) {
+ zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event")
+ success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil)
+ if !success {
+ return false
+ }
+ portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet
}
portal.lastCapUpdate = time.Now()
if implicit {
@@ -3451,15 +4330,27 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri
return
}
-func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool {
+func (portal *Portal) sendRoomMeta(
+ ctx context.Context,
+ sender MatrixAPI,
+ ts time.Time,
+ eventType event.Type,
+ stateKey string,
+ content any,
+ excludeFromTimeline bool,
+ extra map[string]any,
+) bool {
if portal.MXID == "" {
return false
}
- var extra map[string]any
+ if extra == nil {
+ extra = make(map[string]any)
+ }
+ if excludeFromTimeline {
+ extra["com.beeper.exclude_from_timeline"] = true
+ }
if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) {
- extra = map[string]any{
- "fi.mau.implicit_name": true,
- }
+ extra["fi.mau.implicit_name"] = true
}
_, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{
Parsed: content,
@@ -3471,9 +4362,55 @@ func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts tim
Msg("Failed to set room metadata")
return false
}
+ if eventType == event.StateBeeperDisappearingTimer {
+ // TODO remove this debug log at some point
+ zerolog.Ctx(ctx).Debug().
+ Any("content", content).
+ Msg("Sent new disappearing timer event")
+ }
return true
}
+func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) {
+ if !portal.Bridge.Config.RevertFailedStateChanges {
+ return
+ }
+ if evt.GetStateKey() != "" && evt.Type != event.StateMember {
+ return
+ }
+ switch evt.Type {
+ case event.StateRoomName:
+ portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil)
+ case event.StateRoomAvatar:
+ portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil)
+ case event.StateTopic:
+ portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil)
+ case event.StateBeeperDisappearingTimer:
+ portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil)
+ case event.StateMember:
+ var prevContent *event.MemberEventContent
+ var extra map[string]any
+ if evt.Unsigned.PrevContent != nil {
+ _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
+ prevContent = evt.Unsigned.PrevContent.AsMember()
+ newContent := evt.Content.AsMember()
+ if prevContent.Membership == newContent.Membership {
+ return
+ }
+ extra = evt.Unsigned.PrevContent.Raw
+ } else {
+ prevContent = &event.MemberEventContent{Membership: event.MembershipLeave}
+ }
+ if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange {
+ if extra == nil {
+ extra = make(map[string]any)
+ }
+ extra["com.beeper.member_rollback"] = true
+ portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra)
+ }
+ }
+}
+
func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) {
if members == nil {
invite = []id.UserID{source.UserMXID}
@@ -3557,6 +4494,39 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi
return false
}
+func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool {
+ switch rule.JoinRule {
+ case event.JoinRulePublic:
+ return true
+ case event.JoinRuleKnockRestricted, event.JoinRuleRestricted:
+ for _, allow := range rule.Allow {
+ if allow.Type == "fi.mau.spam_checker" {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func (portal *Portal) roomIsPublic(ctx context.Context) bool {
+ mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ return false
+ }
+ evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "")
+ if err != nil {
+ zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public")
+ return false
+ } else if evt == nil {
+ return false
+ }
+ content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent)
+ if !ok {
+ return false
+ }
+ return looksDirectlyJoinable(content)
+}
+
func (portal *Portal) syncParticipants(
ctx context.Context,
members *ChatMemberList,
@@ -3587,6 +4557,12 @@ func (portal *Portal) syncParticipants(
}
delete(currentMembers, portal.Bridge.Bot.GetMXID())
powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower)
+ addExcludeFromTimeline := func(raw map[string]any) {
+ _, hasKey := raw["com.beeper.exclude_from_timeline"]
+ if !hasKey && members.ExcludeChangesFromTimeline {
+ raw["com.beeper.exclude_from_timeline"] = true
+ }
+ }
syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool {
if member.Membership == "" {
member.Membership = event.MembershipJoin
@@ -3616,12 +4592,10 @@ func (portal *Portal) syncParticipants(
Displayname: currentMember.Displayname,
AvatarURL: currentMember.AvatarURL,
}
- wrappedContent := &event.Content{Parsed: content, Raw: maps.Clone(member.MemberEventExtra)}
- if wrappedContent.Raw == nil {
- wrappedContent.Raw = make(map[string]any)
- }
+ wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)}
+ addExcludeFromTimeline(wrappedContent.Raw)
thisEvtSender := sender
- if member.Membership == event.MembershipJoin {
+ if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) {
content.Membership = event.MembershipInvite
if intent != nil {
wrappedContent.Raw["fi.mau.will_auto_accept"] = true
@@ -3651,7 +4625,11 @@ func (portal *Portal) syncParticipants(
currentMember.Membership = event.MembershipLeave
}
}
- _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts)
+ if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID {
+ _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts)
+ } else {
+ _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts)
+ }
if err != nil {
addLogContext(log.Err(err)).
Str("new_membership", string(content.Membership)).
@@ -3664,7 +4642,8 @@ func (portal *Portal) syncParticipants(
if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin {
content.Membership = event.MembershipJoin
- wrappedJoinContent := &event.Content{Parsed: content, Raw: member.MemberEventExtra}
+ wrappedJoinContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)}
+ addExcludeFromTimeline(wrappedContent.Raw)
_, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts)
if err != nil {
addLogContext(log.Err(err)).
@@ -3727,7 +4706,7 @@ func (portal *Portal) syncParticipants(
if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan {
continue
}
- if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil {
+ if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) {
continue
}
_, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{
@@ -3737,6 +4716,9 @@ func (portal *Portal) syncParticipants(
Displayname: memberEvt.Displayname,
Reason: "User is not in remote chat",
},
+ Raw: map[string]any{
+ "com.beeper.exclude_from_timeline": members.ExcludeChangesFromTimeline,
+ },
}, time.Now())
if err != nil {
zerolog.Ctx(ctx).Err(err).
@@ -3805,16 +4787,28 @@ func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.M
return content
}
-func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool {
- if setting.Timer == 0 {
- setting.Type = ""
- }
+type UpdateDisappearingSettingOpts struct {
+ Sender MatrixAPI
+ Timestamp time.Time
+ Implicit bool
+ Save bool
+ SendNotice bool
+
+ ExcludeFromTimeline bool
+}
+
+func (portal *Portal) UpdateDisappearingSetting(
+ ctx context.Context,
+ setting database.DisappearingSetting,
+ opts UpdateDisappearingSettingOpts,
+) bool {
+ setting = setting.Normalize()
if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type {
return false
}
portal.Disappear.Type = setting.Type
portal.Disappear.Timer = setting.Timer
- if save {
+ if opts.Save {
err := portal.Save(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting")
@@ -3823,19 +4817,45 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat
if portal.MXID == "" {
return true
}
- content := DisappearingMessageNotice(setting.Timer, implicit)
- if sender == nil {
- sender = portal.Bridge.Bot
+
+ if opts.Sender == nil {
+ opts.Sender = portal.Bridge.Bot
}
- _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{
+ if opts.Timestamp.IsZero() {
+ opts.Timestamp = time.Now()
+ }
+ portal.sendRoomMeta(
+ ctx,
+ opts.Sender,
+ opts.Timestamp,
+ event.StateBeeperDisappearingTimer,
+ "",
+ setting.ToEventContent(),
+ opts.ExcludeFromTimeline,
+ nil,
+ )
+
+ if !opts.SendNotice {
+ return true
+ }
+ content := DisappearingMessageNotice(setting.Timer, opts.Implicit)
+ _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{
Parsed: content,
- }, &MatrixSendExtra{Timestamp: ts})
+ Raw: map[string]any{
+ "com.beeper.action_message": map[string]any{
+ "type": "disappearing_timer",
+ "timer": setting.Timer.Milliseconds(),
+ "timer_type": setting.Type,
+ "implicit": opts.Implicit,
+ },
+ },
+ }, &MatrixSendExtra{Timestamp: opts.Timestamp})
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice")
} else {
zerolog.Ctx(ctx).Debug().
Dur("new_timer", portal.Disappear.Timer).
- Bool("implicit", implicit).
+ Bool("implicit", opts.Implicit).
Msg("Sent disappearing messages notice")
}
return true
@@ -3897,13 +4917,13 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch
return
}
}
- changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}) || changed
+ changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}, false) || changed
changed = portal.updateAvatar(ctx, &Avatar{
ID: ghost.AvatarID,
MXC: ghost.AvatarMXC,
Hash: ghost.AvatarHash,
Remove: ghost.AvatarID == "",
- }, nil, time.Time{}) || changed
+ }, nil, time.Time{}, false) || changed
return
}
@@ -3912,28 +4932,36 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us
if info.Name == DefaultChatName {
if portal.NameIsCustom {
portal.NameIsCustom = false
- changed = portal.updateName(ctx, "", sender, ts) || changed
+ changed = portal.updateName(ctx, "", sender, ts, info.ExcludeChangesFromTimeline) || changed
}
} else if info.Name != nil {
portal.NameIsCustom = true
- changed = portal.updateName(ctx, *info.Name, sender, ts) || changed
+ changed = portal.updateName(ctx, *info.Name, sender, ts, info.ExcludeChangesFromTimeline) || changed
}
if info.Topic != nil {
- changed = portal.updateTopic(ctx, *info.Topic, sender, ts) || changed
+ changed = portal.updateTopic(ctx, *info.Topic, sender, ts, info.ExcludeChangesFromTimeline) || changed
}
if info.Avatar != nil {
portal.NameIsCustom = true
- changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed
+ changed = portal.updateAvatar(ctx, info.Avatar, sender, ts, info.ExcludeChangesFromTimeline) || changed
}
if info.Disappear != nil {
- changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed
+ changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{
+ Sender: sender,
+ Timestamp: ts,
+ Implicit: false,
+ Save: false,
+
+ SendNotice: !info.ExcludeChangesFromTimeline,
+ ExcludeFromTimeline: info.ExcludeChangesFromTimeline,
+ }) || changed
}
if info.ParentID != nil {
changed = portal.updateParent(ctx, *info.ParentID, source) || changed
}
if info.JoinRule != nil {
// TODO change detection instead of spamming this every time?
- portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule)
+ portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil)
}
if info.Type != nil && portal.RoomType != *info.Type {
if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) {
@@ -3946,6 +4974,10 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us
portal.RoomType = *info.Type
}
}
+ if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest {
+ changed = true
+ portal.MessageRequest = *info.MessageRequest
+ }
if info.Members != nil && portal.MXID != "" && source != nil {
err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{})
if err != nil {
@@ -3987,6 +5019,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
}
return nil
}
+ if portal.deleted.IsSet() {
+ return ErrPortalIsDeleted
+ }
waiter := make(chan struct{})
closed := false
evt := &portalCreateEvent{
@@ -4004,7 +5039,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
if PortalEventBuffer == 0 {
go portal.queueEvent(ctx, evt)
} else {
- portal.events <- evt
+ select {
+ case portal.events <- evt:
+ case <-portal.deleted.GetChan():
+ return ErrPortalIsDeleted
+ }
}
select {
case <-ctx.Done():
@@ -4015,7 +5054,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
}
func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error {
+ cancellableCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ portal.cancelRoomCreate.CompareAndSwap(nil, &cancel)
portal.roomCreateLock.Lock()
+ portal.cancelRoomCreate.Store(&cancel)
defer portal.roomCreateLock.Unlock()
if portal.MXID != "" {
if source != nil {
@@ -4026,6 +5069,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
log := zerolog.Ctx(ctx).With().
Str("action", "create matrix room").
Logger()
+ cancellableCtx = log.WithContext(cancellableCtx)
ctx = log.WithContext(ctx)
log.Info().Msg("Creating Matrix room")
@@ -4034,16 +5078,16 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
if info != nil {
log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info")
}
- info, err = source.Client.GetChatInfo(ctx, portal)
+ info, err = source.Client.GetChatInfo(cancellableCtx, portal)
if err != nil {
log.Err(err).Msg("Failed to update portal info for creation")
return err
}
}
- portal.UpdateInfo(ctx, info, source, nil, time.Time{})
- if ctx.Err() != nil {
- return ctx.Err()
+ portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{})
+ if cancellableCtx.Err() != nil {
+ return cancellableCtx.Err()
}
powerLevels := &event.PowerLevelsEventContent{
@@ -4056,7 +5100,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
portal.Bridge.Bot.GetMXID(): 9001,
},
}
- initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels)
+ initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels)
if err != nil {
log.Err(err).Msg("Failed to process participant list for portal creation")
return err
@@ -4065,15 +5109,12 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
req := mautrix.ReqCreateRoom{
Visibility: "private",
- Name: portal.Name,
- Topic: portal.Topic,
CreationContent: make(map[string]any),
InitialState: make([]*event.Event, 0, 6),
Preset: "private_chat",
IsDirect: portal.RoomType == database.RoomTypeDM,
PowerLevelOverride: powerLevels,
BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey),
- RoomVersion: event.RoomV11,
}
autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites
if autoJoinInvites {
@@ -4086,7 +5127,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
req.CreationContent["type"] = event.RoomTypeSpace
}
bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo()
- roomFeatures := source.Client.GetCapabilities(ctx, portal)
+ roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal)
portal.CapState = database.CapabilityState{
Source: source.ID,
ID: roomFeatures.GetID(),
@@ -4109,19 +5150,47 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
StateKey: &bridgeInfoStateKey,
Type: event.StateBeeperRoomFeatures,
Content: event.Content{Parsed: roomFeatures},
+ }, &event.Event{
+ Type: event.StateTopic,
+ Content: event.Content{
+ Parsed: &event.TopicEventContent{Topic: portal.Topic},
+ Raw: map[string]any{
+ "com.beeper.exclude_from_timeline": true,
+ },
+ },
})
- if req.Topic == "" {
- // Add explicit topic event if topic is empty to ensure the event is set.
- // This ensures that there won't be an extra event later if PUT /state/... is called.
+ if roomFeatures.DisappearingTimer != nil {
req.InitialState = append(req.InitialState, &event.Event{
- Type: event.StateTopic,
- Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}},
+ Type: event.StateBeeperDisappearingTimer,
+ Content: event.Content{
+ Parsed: portal.Disappear.ToEventContent(),
+ Raw: map[string]any{
+ "com.beeper.exclude_from_timeline": true,
+ },
+ },
+ })
+ portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet
+ }
+ if portal.Name != "" {
+ req.InitialState = append(req.InitialState, &event.Event{
+ Type: event.StateRoomName,
+ Content: event.Content{
+ Parsed: &event.RoomNameEventContent{Name: portal.Name},
+ Raw: map[string]any{
+ "com.beeper.exclude_from_timeline": true,
+ },
+ },
})
}
if portal.AvatarMXC != "" {
req.InitialState = append(req.InitialState, &event.Event{
- Type: event.StateRoomAvatar,
- Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}},
+ Type: event.StateRoomAvatar,
+ Content: event.Content{
+ Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC},
+ Raw: map[string]any{
+ "com.beeper.exclude_from_timeline": true,
+ },
+ },
})
}
if portal.Parent != nil && portal.Parent.MXID != "" {
@@ -4140,6 +5209,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
Content: event.Content{Parsed: info.JoinRule},
})
}
+ if cancellableCtx.Err() != nil {
+ return cancellableCtx.Err()
+ }
roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req)
if err != nil {
log.Err(err).Msg("Failed to create Matrix room")
@@ -4150,6 +5222,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
portal.TopicSet = true
portal.NameSet = true
portal.MXID = roomID
+ portal.RoomCreated.Set()
portal.Bridge.cacheLock.Lock()
portal.Bridge.portalsByMXID[roomID] = portal
portal.Bridge.cacheLock.Unlock()
@@ -4196,42 +5269,55 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
}
}
}
- if portal.Parent == nil {
- if portal.Receiver != "" {
- login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver)
- if login != nil {
- up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey)
- if err != nil {
- log.Err(err).Msg("Failed to get user portal to add portal to spaces")
- } else {
- login.inPortalCache.Remove(portal.PortalKey)
- go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
- }
- }
- } else {
- userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
- if err != nil {
- log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces")
- } else {
- for _, up := range userPortals {
- login := portal.Bridge.GetCachedUserLoginByID(up.LoginID)
- if login != nil {
- login.inPortalCache.Remove(portal.PortalKey)
- go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
- }
- }
- }
- }
- }
- if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background {
+ portal.addToUserSpaces(ctx)
+ if info.CanBackfill &&
+ portal.Bridge.Config.Backfill.Enabled &&
+ portal.RoomType != database.RoomTypeSpace &&
+ !portal.Bridge.Background {
portal.doForwardBackfill(ctx, source, nil, backfillBundle)
}
return nil
}
+func (portal *Portal) addToUserSpaces(ctx context.Context) {
+ if portal.Parent != nil {
+ return
+ }
+ log := zerolog.Ctx(ctx)
+ withoutCancelCtx := log.WithContext(portal.Bridge.BackgroundCtx)
+ if portal.Receiver != "" {
+ login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver)
+ if login != nil {
+ up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey)
+ if err != nil {
+ log.Err(err).Msg("Failed to get user portal to add portal to spaces")
+ } else {
+ login.inPortalCache.Remove(portal.PortalKey)
+ go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
+ }
+ }
+ } else {
+ userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
+ if err != nil {
+ log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces")
+ } else {
+ for _, up := range userPortals {
+ login := portal.Bridge.GetCachedUserLoginByID(up.LoginID)
+ if login != nil {
+ login.inPortalCache.Remove(portal.PortalKey)
+ go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
+ }
+ }
+ }
+ }
+}
+
func (portal *Portal) Delete(ctx context.Context) error {
+ if portal.deleted.IsSet() {
+ return nil
+ }
portal.removeInPortalCache(ctx)
- err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
+ err := portal.safeDBDelete(ctx)
if err != nil {
return err
}
@@ -4241,11 +5327,21 @@ func (portal *Portal) Delete(ctx context.Context) error {
return nil
}
+func (portal *Portal) safeDBDelete(ctx context.Context) error {
+ err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey)
+ if err != nil {
+ return fmt.Errorf("failed to delete messages in portal: %w", err)
+ }
+ // TODO delete child portals?
+ return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
+}
+
func (portal *Portal) RemoveMXID(ctx context.Context) error {
if portal.MXID == "" {
return nil
}
portal.MXID = ""
+ portal.RoomCreated.Clear()
err := portal.Save(ctx)
if err != nil {
return err
@@ -4278,8 +5374,10 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) {
}
func (portal *Portal) unlockedDelete(ctx context.Context) error {
- // TODO delete child portals?
- err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
+ if portal.deleted.IsSet() {
+ return nil
+ }
+ err := portal.safeDBDelete(ctx)
if err != nil {
return err
}
@@ -4288,10 +5386,14 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error {
}
func (portal *Portal) unlockedDeleteCache() {
+ if portal.deleted.IsSet() {
+ return
+ }
delete(portal.Bridge.portalsByKey, portal.PortalKey)
if portal.MXID != "" {
delete(portal.Bridge.portalsByMXID, portal.MXID)
}
+ portal.deleted.Set()
if portal.events != nil {
// TODO there's a small risk of this racing with a queueEvent call
close(portal.events)
@@ -4303,6 +5405,9 @@ func (portal *Portal) Save(ctx context.Context) error {
}
func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error {
+ if portal.Receiver != "" && relay.ID != portal.Receiver {
+ return fmt.Errorf("can't set non-receiver login as relay")
+ }
portal.Relay = relay
if relay == nil {
portal.RelayLoginID = ""
diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go
index 9883fb12..879f07ae 100644
--- a/bridgev2/portalbackfill.go
+++ b/bridgev2/portalbackfill.go
@@ -194,6 +194,9 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t
if err != nil {
log.Err(err).Msg("Failed to get last thread message")
return
+ } else if anchorMessage == nil {
+ log.Warn().Msg("No messages found in thread?")
+ return
}
resp := portal.fetchThreadBackfill(ctx, source, anchorMessage)
if resp != nil {
@@ -339,6 +342,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
for i, part := range msg.Parts {
partIDs = append(partIDs, part.ID)
portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent)
+ part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent()
evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID)
dbMessage := &database.Message{
ID: msg.ID,
@@ -379,19 +383,23 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
prevThreadEvent.MXID = evtID
out.PrevThreadEvents[*msg.ThreadRoot] = evtID
}
- if msg.Disappear.Type != database.DisappearingTypeNone {
- if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
+ if msg.Disappear.Type != event.DisappearingTypeNone {
+ if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer)
}
out.Disappear = append(out.Disappear, &database.DisappearingMessage{
RoomID: portal.MXID,
EventID: evtID,
+ Timestamp: msg.Timestamp,
DisappearingSetting: msg.Disappear,
})
}
}
slices.Sort(partIDs)
for _, reaction := range msg.Reactions {
+ if reaction == nil {
+ continue
+ }
reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove)
if !ok {
continue
@@ -402,6 +410,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
if reaction.Timestamp.IsZero() {
reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond)
}
+ //lint:ignore SA4006 it's a todo
targetPart, ok := partMap[*reaction.TargetPart]
if !ok {
// TODO warning log and/or skip reaction?
diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go
index e82c481a..4c7e2447 100644
--- a/bridgev2/portalinternal.go
+++ b/bridgev2/portalinternal.go
@@ -37,8 +37,8 @@ func (portal *PortalInternals) EventLoop() {
(*Portal)(portal).eventLoop()
}
-func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) {
- return (*Portal)(portal).handleSingleEventAsync(idx, rawEvt)
+func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) {
+ return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt)
}
func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context {
@@ -49,6 +49,10 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any
(*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback)
}
+func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
+ return (*Portal)(portal).unwrapBeeperSendState(ctx, evt)
+}
+
func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) {
(*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID)
}
@@ -61,8 +65,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i
return (*Portal)(portal).checkConfusableName(ctx, userID, name)
}
-func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt)
+func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest)
}
func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult {
@@ -73,6 +77,10 @@ func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user
(*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt)
}
+func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) {
+ (*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal)
+}
+
func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixTyping(ctx, evt)
}
@@ -117,12 +125,24 @@ func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.User
return (*Portal)(portal).getTargetUser(ctx, userID)
}
-func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt)
+func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt)
}
-func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt)
+func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest)
+}
+
+func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest)
+}
+
+func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixTombstone(ctx, evt)
+}
+
+func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) {
+ (*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser)
}
func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
@@ -133,6 +153,10 @@ func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *Us
return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt)
}
+func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) {
+ (*Portal)(portal).ensureFunctionalMember(ctx, ghost)
+}
+
func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) {
return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType)
}
@@ -233,6 +257,10 @@ func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, sourc
return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt)
}
+func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) {
+ return (*Portal)(portal).findOtherLogins(ctx, source)
+}
+
func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult {
return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt)
}
@@ -241,16 +269,16 @@ func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source
return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill)
}
-func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool {
- return (*Portal)(portal).updateName(ctx, name, sender, ts)
+func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
+ return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline)
}
-func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool {
- return (*Portal)(portal).updateTopic(ctx, topic, sender, ts)
+func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
+ return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline)
}
-func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool {
- return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts)
+func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
+ return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline)
}
func (portal *PortalInternals) GetBridgeInfoStateKey() string {
@@ -265,8 +293,12 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen
return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts)
}
-func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool {
- return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content)
+func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool {
+ return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra)
+}
+
+func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) {
+ (*Portal)(portal).revertRoomMeta(ctx, evt)
}
func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) {
@@ -277,6 +309,10 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha
return (*Portal)(portal).updateOtherUser(ctx, members)
}
+func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool {
+ return (*Portal)(portal).roomIsPublic(ctx)
+}
+
func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error {
return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts)
}
@@ -297,6 +333,10 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc
return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle)
}
+func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) {
+ (*Portal)(portal).addToUserSpaces(ctx)
+}
+
func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) {
(*Portal)(portal).removeInPortalCache(ctx)
}
@@ -360,7 +400,3 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save
func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error {
return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove)
}
-
-func (portal *PortalInternals) SetMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
- return (*Portal)(portal).setMXIDToExistingRoom(ctx, roomID)
-}
diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go
index a25fe820..c976d97c 100644
--- a/bridgev2/portalreid.go
+++ b/bridgev2/portalreid.go
@@ -32,21 +32,40 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
if source == target {
return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same")
}
- log := zerolog.Ctx(ctx)
- log.Debug().Msg("Re-ID'ing portal")
+ log := zerolog.Ctx(ctx).With().
+ Str("action", "re-id portal").
+ Stringer("source_portal_key", source).
+ Stringer("target_portal_key", target).
+ Logger()
+ ctx = log.WithContext(ctx)
defer func() {
log.Debug().Msg("Finished handling portal re-ID")
}()
- br.cacheLock.Lock()
- defer br.cacheLock.Unlock()
- sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true)
+ acquireCacheLock := func() {
+ if !br.cacheLock.TryLock() {
+ log.Debug().Msg("Waiting for global cache lock")
+ br.cacheLock.Lock()
+ log.Debug().Msg("Acquired global cache lock after waiting")
+ } else {
+ log.Trace().Msg("Acquired global cache lock without waiting")
+ }
+ }
+ log.Debug().Msg("Re-ID'ing portal")
+ sourcePortal, err := br.GetExistingPortalByKey(ctx, source)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err)
} else if sourcePortal == nil {
log.Debug().Msg("Source portal not found, re-ID is no-op")
return ReIDResultNoOp, nil, nil
}
- sourcePortal.roomCreateLock.Lock()
+ if !sourcePortal.roomCreateLock.TryLock() {
+ if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
+ (*cancelCreate)()
+ }
+ log.Debug().Msg("Waiting for source portal room creation lock")
+ sourcePortal.roomCreateLock.Lock()
+ log.Debug().Msg("Acquired source portal room creation lock after waiting")
+ }
defer sourcePortal.roomCreateLock.Unlock()
if sourcePortal.MXID == "" {
log.Info().Msg("Source portal doesn't have Matrix room, deleting row")
@@ -59,22 +78,37 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Stringer("source_portal_mxid", sourcePortal.MXID)
})
+
+ acquireCacheLock()
targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true)
if err != nil {
+ br.cacheLock.Unlock()
return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err)
}
if targetPortal == nil {
log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal")
err = sourcePortal.unlockedReID(ctx, target)
+ br.cacheLock.Unlock()
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err)
}
return ReIDResultSourceReIDd, sourcePortal, nil
}
- targetPortal.roomCreateLock.Lock()
+ br.cacheLock.Unlock()
+
+ if !targetPortal.roomCreateLock.TryLock() {
+ if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
+ (*cancelCreate)()
+ }
+ log.Debug().Msg("Waiting for target portal room creation lock")
+ targetPortal.roomCreateLock.Lock()
+ log.Debug().Msg("Acquired target portal room creation lock after waiting")
+ }
defer targetPortal.roomCreateLock.Unlock()
if targetPortal.MXID == "" {
log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal")
+ acquireCacheLock()
+ defer br.cacheLock.Unlock()
err = targetPortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err)
@@ -89,6 +123,9 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
return c.Stringer("target_portal_mxid", targetPortal.MXID)
})
log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal")
+ sourcePortal.removeInPortalCache(ctx)
+ acquireCacheLock()
+ defer br.cacheLock.Unlock()
err = sourcePortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err)
@@ -96,7 +133,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
go func() {
_, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{
Parsed: &event.TombstoneEventContent{
- Body: fmt.Sprintf("This room has been merged"),
+ Body: "This room has been merged",
ReplacementRoom: targetPortal.MXID,
},
}, time.Now())
diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go
new file mode 100644
index 00000000..72bacaff
--- /dev/null
+++ b/bridgev2/provisionutil/creategroup.go
@@ -0,0 +1,149 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package provisionutil
+
+import (
+ "context"
+
+ "github.com/rs/zerolog"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/bridgev2"
+ "maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type RespCreateGroup struct {
+ ID networkid.PortalID `json:"id"`
+ MXID id.RoomID `json:"mxid"`
+ Portal *bridgev2.Portal `json:"-"`
+
+ FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"`
+}
+
+func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) {
+ api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI)
+ if !ok {
+ return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups"))
+ }
+ zerolog.Ctx(ctx).Debug().
+ Any("create_params", params).
+ Msg("Creating group chat on remote network")
+ caps := login.Bridge.Network.GetCapabilities()
+ typeSpec, validType := caps.Provisioning.GroupCreation[params.Type]
+ if !validType {
+ return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type))
+ }
+ if len(params.Participants) < typeSpec.Participants.MinLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength))
+ } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength))
+ }
+ userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
+ for i, participant := range params.Participants {
+ parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant))
+ if ok {
+ participant = parsedParticipant
+ params.Participants[i] = participant
+ }
+ if !typeSpec.Participants.SkipIdentifierValidation {
+ if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant))
+ }
+ }
+ if api.IsThisUser(ctx, participant) {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant))
+ }
+ }
+ if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required"))
+ } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength))
+ } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength))
+ }
+ if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required"))
+ }
+ if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required"))
+ } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength))
+ } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength))
+ }
+ if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required"))
+ } else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer"))
+ }
+ if params.Username == "" && typeSpec.Username.Required {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required"))
+ } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength))
+ } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength))
+ }
+ if params.Parent == nil && typeSpec.Parent.Required {
+ return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required"))
+ }
+ resp, err := api.CreateGroup(ctx, params)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to create group")
+ return nil, err
+ }
+ if resp.PortalKey.IsEmpty() {
+ return nil, ErrNoPortalKey
+ }
+ zerolog.Ctx(ctx).Debug().
+ Object("portal_key", resp.PortalKey).
+ Msg("Successfully created group on remote network")
+ if resp.Portal == nil {
+ resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
+ return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
+ }
+ }
+ if resp.Portal.MXID == "" {
+ err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
+ return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
+ }
+ }
+ for key, fp := range resp.FailedParticipants {
+ if fp.InviteEventType == "" {
+ fp.InviteEventType = event.EventMessage.Type
+ }
+ if fp.UserMXID == "" {
+ ghost, err := login.Bridge.GetGhostByID(ctx, key)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant")
+ } else if ghost != nil {
+ fp.UserMXID = ghost.Intent.GetMXID()
+ }
+ }
+ if fp.DMRoomMXID == "" {
+ portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant")
+ } else if portal != nil {
+ fp.DMRoomMXID = portal.MXID
+ }
+ }
+ }
+ return &RespCreateGroup{
+ ID: resp.Portal.ID,
+ MXID: resp.Portal.MXID,
+ Portal: resp.Portal,
+
+ FailedParticipants: resp.FailedParticipants,
+ }, nil
+}
diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go
new file mode 100644
index 00000000..ce163e67
--- /dev/null
+++ b/bridgev2/provisionutil/listcontacts.go
@@ -0,0 +1,98 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package provisionutil
+
+import (
+ "context"
+
+ "github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/bridgev2"
+)
+
+type RespGetContactList struct {
+ Contacts []*RespResolveIdentifier `json:"contacts"`
+}
+
+type RespSearchUsers struct {
+ Results []*RespResolveIdentifier `json:"results"`
+}
+
+func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) {
+ api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
+ if !ok {
+ return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts"))
+ }
+ resp, err := api.GetContactList(ctx)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
+ return nil, err
+ }
+ return &RespGetContactList{
+ Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false),
+ }, nil
+}
+
+func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) {
+ api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
+ if !ok {
+ return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users"))
+ }
+ resp, err := api.SearchUsers(ctx, query)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
+ return nil, err
+ }
+ return &RespSearchUsers{
+ Results: processResolveIdentifiers(ctx, login.Bridge, resp, true),
+ }, nil
+}
+
+func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) {
+ apiResp = make([]*RespResolveIdentifier, len(resp))
+ for i, contact := range resp {
+ apiContact := &RespResolveIdentifier{
+ ID: contact.UserID,
+ }
+ apiResp[i] = apiContact
+ if contact.UserInfo != nil {
+ if contact.UserInfo.Name != nil {
+ apiContact.Name = *contact.UserInfo.Name
+ }
+ if contact.UserInfo.Identifiers != nil {
+ apiContact.Identifiers = contact.UserInfo.Identifiers
+ }
+ }
+ if contact.Ghost != nil {
+ if syncInfo && contact.UserInfo != nil {
+ contact.Ghost.UpdateInfo(ctx, contact.UserInfo)
+ }
+ if contact.Ghost.Name != "" {
+ apiContact.Name = contact.Ghost.Name
+ }
+ if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
+ apiContact.Identifiers = contact.Ghost.Identifiers
+ }
+ apiContact.AvatarURL = contact.Ghost.AvatarMXC
+ apiContact.MXID = contact.Ghost.Intent.GetMXID()
+ }
+ if contact.Chat != nil {
+ if contact.Chat.Portal == nil {
+ var err error
+ contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
+ }
+ }
+ if contact.Chat.Portal != nil {
+ apiContact.DMRoomID = contact.Chat.Portal.MXID
+ }
+ }
+ }
+ return
+}
diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go
new file mode 100644
index 00000000..cfc388d0
--- /dev/null
+++ b/bridgev2/provisionutil/resolveidentifier.go
@@ -0,0 +1,125 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package provisionutil
+
+import (
+ "context"
+ "errors"
+
+ "github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/bridgev2"
+ "maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/id"
+)
+
+type RespResolveIdentifier struct {
+ ID networkid.UserID `json:"id"`
+ Name string `json:"name,omitempty"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ Identifiers []string `json:"identifiers,omitempty"`
+ MXID id.UserID `json:"mxid,omitempty"`
+ DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
+
+ Portal *bridgev2.Portal `json:"-"`
+ Ghost *bridgev2.Ghost `json:"-"`
+ JustCreated bool `json:"-"`
+}
+
+var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request")
+
+func ResolveIdentifier(
+ ctx context.Context,
+ login *bridgev2.UserLogin,
+ identifier string,
+ createChat bool,
+) (*RespResolveIdentifier, error) {
+ api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
+ if !ok {
+ return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers"))
+ }
+ var resp *bridgev2.ResolveIdentifierResponse
+ parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier))
+ validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
+ if ok && (!vOK || validator.ValidateUserID(parsedUserID)) {
+ ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID")
+ return nil, err
+ }
+ resp = &bridgev2.ResolveIdentifierResponse{
+ Ghost: ghost,
+ UserID: parsedUserID,
+ }
+ gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI)
+ if ok && createChat {
+ resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat")
+ return nil, err
+ }
+ } else if createChat || ghost.Name == "" {
+ zerolog.Ctx(ctx).Debug().
+ Bool("create_chat", createChat).
+ Bool("has_name", ghost.Name != "").
+ Msg("Falling back to resolving identifier")
+ resp = nil
+ identifier = string(parsedUserID)
+ }
+ }
+ if resp == nil {
+ var err error
+ resp, err = api.ResolveIdentifier(ctx, identifier, createChat)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier")
+ return nil, err
+ } else if resp == nil {
+ return nil, nil
+ }
+ }
+ apiResp := &RespResolveIdentifier{
+ ID: resp.UserID,
+ Ghost: resp.Ghost,
+ }
+ if resp.Ghost != nil {
+ if resp.UserInfo != nil {
+ resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
+ }
+ apiResp.Name = resp.Ghost.Name
+ apiResp.AvatarURL = resp.Ghost.AvatarMXC
+ apiResp.Identifiers = resp.Ghost.Identifiers
+ apiResp.MXID = resp.Ghost.Intent.GetMXID()
+ } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
+ apiResp.Name = *resp.UserInfo.Name
+ }
+ if resp.Chat != nil {
+ if resp.Chat.PortalKey.IsEmpty() {
+ return nil, ErrNoPortalKey
+ }
+ if resp.Chat.Portal == nil {
+ var err error
+ resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
+ return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
+ }
+ }
+ resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID)
+ if createChat && resp.Chat.Portal.MXID == "" {
+ apiResp.JustCreated = true
+ err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
+ return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
+ }
+ }
+ apiResp.Portal = resp.Chat.Portal
+ apiResp.DMRoomID = resp.Chat.Portal.MXID
+ }
+ return apiResp, nil
+}
diff --git a/bridgev2/queue.go b/bridgev2/queue.go
index 04d982b5..3775c825 100644
--- a/bridgev2/queue.go
+++ b/bridgev2/queue.go
@@ -63,6 +63,13 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve
return true
}
+var (
+ ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
+ ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
+ ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage())
+ ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage()
+)
+
func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult {
// TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands
@@ -78,13 +85,11 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
return EventHandlingResultFailed
} else if sender == nil {
log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event")
- status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
- br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
+ br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
return EventHandlingResultFailed
} else if !sender.Permissions.SendEvents {
if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") {
- status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
- br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
+ br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt))
}
return EventHandlingResultIgnored
} else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") {
@@ -92,8 +97,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
}
} else if evt.Type.Class != event.EphemeralEventType {
log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event")
- status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
- br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
+ br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
return EventHandlingResultIgnored
}
if evt.Type == event.EventMessage && sender != nil {
@@ -102,11 +106,10 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
msg.RemovePerMessageProfileFallback()
if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom {
if !sender.Permissions.Commands {
- status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
- br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
+ br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt))
return EventHandlingResultIgnored
}
- br.Commands.Handle(
+ go br.Commands.Handle(
ctx,
evt.RoomID,
evt.ID,
@@ -114,7 +117,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "),
msg.RelatesTo.GetReplyTo(),
)
- return EventHandlingResultSuccess
+ return EventHandlingResultQueued
}
}
if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil {
@@ -157,10 +160,27 @@ type EventHandlingResult struct {
Ignored bool
Queued bool
+ SkipStateEcho bool
+
// Error is an optional reason for failure. It is not required, Success may be false even without a specific error.
Error error
// Whether the Error should be sent as a MSS event.
SendMSS bool
+
+ // EventID from the network
+ EventID id.EventID
+ // Stream order from the network
+ StreamOrder int64
+}
+
+func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult {
+ ehr.EventID = id
+ return ehr
+}
+
+func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult {
+ ehr.StreamOrder = order
+ return ehr
}
func (ehr EventHandlingResult) WithError(err error) EventHandlingResult {
@@ -177,6 +197,11 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult {
return ehr
}
+func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult {
+ ehr.SkipStateEcho = skip
+ return ehr
+}
+
func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult {
if err == nil {
return ehr
@@ -195,7 +220,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult {
return ul.Bridge.QueueRemoteEvent(ul, evt)
}
-func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) {
+func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult {
log := login.Log
ctx := log.WithContext(br.BackgroundCtx)
maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver)
@@ -211,14 +236,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res Event
if err != nil {
log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain).
Msg("Failed to get portal to handle remote event")
- return
+ return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err))
} else if portal == nil {
log.Warn().
Stringer("event_type", evt.GetType()).
Object("portal_key", key).
Bool("uncertain_receiver", isUncertain).
Msg("Portal not found to handle remote event")
- return
+ return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler)
}
// TODO put this in a better place, and maybe cache to avoid constant db queries
login.MarkInPortal(ctx, portal)
diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go
index c725141b..56e3a6b1 100644
--- a/bridgev2/simplevent/chat.go
+++ b/bridgev2/simplevent/chat.go
@@ -65,14 +65,19 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal)
type ChatDelete struct {
EventMeta
OnlyForMe bool
+ Children bool
}
-var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil)
+var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil)
func (evt *ChatDelete) DeleteOnlyForMe() bool {
return evt.OnlyForMe
}
+func (evt *ChatDelete) DeleteChildren() bool {
+ return evt.Children
+}
+
// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange].
type ChatInfoChange struct {
EventMeta
diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go
index f648ab12..f8f8d7e1 100644
--- a/bridgev2/simplevent/message.go
+++ b/bridgev2/simplevent/message.go
@@ -59,6 +59,41 @@ func (evt *Message[T]) GetTransactionID() networkid.TransactionID {
return evt.TransactionID
}
+// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data.
+type PreConvertedMessage struct {
+ EventMeta
+ Data *bridgev2.ConvertedMessage
+ ID networkid.MessageID
+ TransactionID networkid.TransactionID
+
+ HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error)
+}
+
+var (
+ _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil)
+ _ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil)
+ _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil)
+)
+
+func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) {
+ return evt.Data, nil
+}
+
+func (evt *PreConvertedMessage) GetID() networkid.MessageID {
+ return evt.ID
+}
+
+func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID {
+ return evt.TransactionID
+}
+
+func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) {
+ if evt.HandleExistingFunc == nil {
+ return bridgev2.UpsertResult{}, nil
+ }
+ return evt.HandleExistingFunc(ctx, portal, intent, existing)
+}
+
type MessageRemove struct {
EventMeta
diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go
index 8aa91866..449a8773 100644
--- a/bridgev2/simplevent/meta.go
+++ b/bridgev2/simplevent/meta.go
@@ -101,6 +101,18 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E
return evt
}
+func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta {
+ origFunc := evt.LogContext
+ if origFunc == nil {
+ evt.LogContext = f
+ return evt
+ }
+ evt.LogContext = func(c zerolog.Context) zerolog.Context {
+ return f(origFunc(c))
+ }
+ return evt
+}
+
func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta {
evt.PortalKey = p
return evt
diff --git a/bridgev2/space.go b/bridgev2/space.go
index ccb74b26..2ca2bce3 100644
--- a/bridgev2/space.go
+++ b/bridgev2/space.go
@@ -164,14 +164,17 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) {
ul.UserMXID: 50,
},
},
- RoomVersion: event.RoomV11,
- Invite: []id.UserID{ul.UserMXID},
+ Invite: []id.UserID{ul.UserMXID},
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{ul.UserMXID}
// TODO remove this after initial_members is supported in hungryserv
req.BeeperAutoJoinInvites = true
}
+ pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI)
+ if ok {
+ pfc.CustomizePersonalFilteringSpace(req)
+ }
ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req)
if err != nil {
return "", fmt.Errorf("failed to create space room: %w", err)
diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go
index 01a235a0..5925dd4f 100644
--- a/bridgev2/status/bridgestate.go
+++ b/bridgev2/status/bridgestate.go
@@ -19,9 +19,10 @@ import (
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
- "go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -87,6 +88,8 @@ type RemoteProfile struct {
Username string `json:"username,omitempty"`
Name string `json:"name,omitempty"`
Avatar id.ContentURIString `json:"avatar,omitempty"`
+
+ AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"`
}
func coalesce[T ~string](a, b T) T {
@@ -102,11 +105,14 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile {
other.Username = coalesce(rp.Username, other.Username)
other.Name = coalesce(rp.Name, other.Name)
other.Avatar = coalesce(rp.Avatar, other.Avatar)
+ if rp.AvatarFile != nil {
+ other.AvatarFile = rp.AvatarFile
+ }
return other
}
-func (rp *RemoteProfile) IsEmpty() bool {
- return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "")
+func (rp *RemoteProfile) IsZero() bool {
+ return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil)
}
type BridgeState struct {
@@ -120,10 +126,10 @@ type BridgeState struct {
UserAction BridgeStateUserAction `json:"user_action,omitempty"`
- UserID id.UserID `json:"user_id,omitempty"`
- RemoteID string `json:"remote_id,omitempty"`
- RemoteName string `json:"remote_name,omitempty"`
- RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"`
+ UserID id.UserID `json:"user_id,omitempty"`
+ RemoteID networkid.UserLoginID `json:"remote_id,omitempty"`
+ RemoteName string `json:"remote_name,omitempty"`
+ RemoteProfile RemoteProfile `json:"remote_profile,omitzero"`
Reason string `json:"reason,omitempty"`
Info map[string]interface{} `json:"info,omitempty"`
@@ -203,7 +209,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool {
pong.StateEvent == newPong.StateEvent &&
pong.RemoteName == newPong.RemoteName &&
pong.UserAction == newPong.UserAction &&
- ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) &&
+ pong.RemoteProfile == newPong.RemoteProfile &&
pong.Error == newPong.Error &&
maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) &&
pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now())
diff --git a/bridgev2/status/messagecheckpoint.go b/bridgev2/status/messagecheckpoint.go
index ea859b84..b3c05f4f 100644
--- a/bridgev2/status/messagecheckpoint.go
+++ b/bridgev2/status/messagecheckpoint.go
@@ -169,13 +169,13 @@ type CheckpointsJSON struct {
Checkpoints []*MessageCheckpoint `json:"checkpoints"`
}
-func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error {
+func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpoint string, token string) error {
var body bytes.Buffer
if err := json.NewEncoder(&body).Encode(cj); err != nil {
return fmt.Errorf("failed to encode message checkpoint JSON: %w", err)
}
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &body)
if err != nil {
@@ -186,7 +186,10 @@ func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error {
req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (checkpoint sender)")
req.Header.Set("Content-Type", "application/json")
- resp, err := http.DefaultClient.Do(req)
+ if cli == nil {
+ cli = http.DefaultClient
+ }
+ resp, err := cli.Do(req)
if err != nil {
return mautrix.HTTPError{
Request: req,
diff --git a/bridgev2/user.go b/bridgev2/user.go
index 350cecd1..9a7896d6 100644
--- a/bridgev2/user.go
+++ b/bridgev2/user.go
@@ -176,6 +176,10 @@ func (user *User) GetUserLogins() []*UserLogin {
return maps.Values(user.logins)
}
+func (user *User) HasTooManyLogins() bool {
+ return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins
+}
+
func (user *User) GetFormattedUserLogins() string {
user.Bridge.cacheLock.Lock()
logins := make([]string, len(user.logins))
@@ -225,9 +229,8 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) {
user.MXID: 50,
},
},
- RoomVersion: event.RoomV11,
- Invite: []id.UserID{user.MXID},
- IsDirect: true,
+ Invite: []id.UserID{user.MXID},
+ IsDirect: true,
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{user.MXID}
diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go
index 203dc122..d56dc4cc 100644
--- a/bridgev2/userlogin.go
+++ b/bridgev2/userlogin.go
@@ -10,6 +10,7 @@ import (
"cmp"
"context"
"fmt"
+ "maps"
"slices"
"sync"
"time"
@@ -50,6 +51,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
+ // TODO if loading the user caused the provided userlogin to be loaded, cancel here?
+ // Currently this will double-load it
}
userLogin := &UserLogin{
UserLogin: dbUserLogin,
@@ -140,6 +143,12 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin {
return br.userLoginsByID[id]
}
+func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) {
+ br.cacheLock.Lock()
+ defer br.cacheLock.Unlock()
+ return slices.Collect(maps.Values(br.userLoginsByID))
+}
+
func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
@@ -501,9 +510,9 @@ var _ status.BridgeStateFiller = (*UserLogin)(nil)
func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState {
state.UserID = ul.UserMXID
- state.RemoteID = string(ul.ID)
+ state.RemoteID = ul.ID
state.RemoteName = ul.RemoteName
- state.RemoteProfile = &ul.RemoteProfile
+ state.RemoteProfile = ul.RemoteProfile
filler, ok := ul.Client.(status.BridgeStateFiller)
if ok {
return filler.FillBridgeState(state)
diff --git a/client.go b/client.go
index 6f746015..7062d9b9 100644
--- a/client.go
+++ b/client.go
@@ -13,6 +13,7 @@ import (
"net/http"
"net/url"
"os"
+ "runtime"
"slices"
"strconv"
"strings"
@@ -110,6 +111,8 @@ type Client struct {
// Set to true to disable automatically sleeping on 429 errors.
IgnoreRateLimit bool
+ ResponseSizeLimit int64
+
txnID int32
// Should the ?user_id= query parameter be set in requests?
@@ -139,6 +142,12 @@ type IdentityServerInfo struct {
// Use ParseUserID to extract the server name from a user ID.
// https://spec.matrix.org/v1.2/client-server-api/#server-discovery
func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) {
+ return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName)
+}
+
+const WellKnownMaxSize = 64 * 1024
+
+func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) {
wellKnownURL := url.URL{
Scheme: "https",
Host: serverName,
@@ -150,10 +159,11 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown
return nil, err
}
- req.Header.Set("Accept", "application/json")
- req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)")
+ if runtime.GOOS != "js" {
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)")
+ }
- client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
@@ -162,11 +172,15 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown
if resp.StatusCode == http.StatusNotFound {
return nil, nil
+ } else if resp.ContentLength > WellKnownMaxSize {
+ return nil, errors.New(".well-known response too large")
}
- data, err := io.ReadAll(resp.Body)
+ data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize))
if err != nil {
return nil, err
+ } else if len(data) >= WellKnownMaxSize {
+ return nil, errors.New(".well-known response too large")
}
var wellKnown ClientWellKnown
@@ -317,6 +331,7 @@ const (
LogBodyContextKey contextKey = iota
LogRequestIDContextKey
MaxAttemptsContextKey
+ SyncTokenContextKey
)
func (cli *Client) RequestStart(req *http.Request) {
@@ -371,7 +386,14 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er
}
}
if body := req.Context().Value(LogBodyContextKey); body != nil {
- evt.Interface("req_body", body)
+ switch typedLogBody := body.(type) {
+ case json.RawMessage:
+ evt.RawJSON("req_body", typedLogBody)
+ case string:
+ evt.Str("req_body", typedLogBody)
+ default:
+ panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body))
+ }
}
if errors.Is(err, context.Canceled) {
evt.Msg("Request canceled")
@@ -388,32 +410,43 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin
return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody})
}
-type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error)
+type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error)
type FullRequest struct {
- Method string
- URL string
- Headers http.Header
- RequestJSON interface{}
- RequestBytes []byte
- RequestBody io.Reader
- RequestLength int64
- ResponseJSON interface{}
- MaxAttempts int
- BackoffDuration time.Duration
- SensitiveContent bool
- Handler ClientResponseHandler
- DontReadResponse bool
- Logger *zerolog.Logger
- Client *http.Client
+ Method string
+ URL string
+ Headers http.Header
+ RequestJSON interface{}
+ RequestBytes []byte
+ RequestBody io.Reader
+ RequestLength int64
+ ResponseJSON interface{}
+ MaxAttempts int
+ BackoffDuration time.Duration
+ SensitiveContent bool
+ Handler ClientResponseHandler
+ DontReadResponse bool
+ ResponseSizeLimit int64
+ Logger *zerolog.Logger
+ Client *http.Client
}
var requestID int32
var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes"
func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) {
+ reqID := atomic.AddInt32(&requestID, 1)
+ logger := zerolog.Ctx(ctx)
+ if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
+ logger = params.Logger
+ }
+ ctx = logger.With().
+ Int32("req_id", reqID).
+ Logger().WithContext(ctx)
+
var logBody any
- reqBody := params.RequestBody
+ var reqBody io.Reader
+ var reqLen int64
if params.RequestJSON != nil {
jsonStr, err := json.Marshal(params.RequestJSON)
if err != nil {
@@ -424,33 +457,38 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
}
if params.SensitiveContent && !logSensitiveContent {
logBody = ""
+ } else if len(jsonStr) > 32768 {
+ logBody = fmt.Sprintf("", len(jsonStr))
} else {
- logBody = params.RequestJSON
+ logBody = json.RawMessage(jsonStr)
}
reqBody = bytes.NewReader(jsonStr)
+ reqLen = int64(len(jsonStr))
} else if params.RequestBytes != nil {
logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes))
reqBody = bytes.NewReader(params.RequestBytes)
- params.RequestLength = int64(len(params.RequestBytes))
- } else if params.RequestLength > 0 && params.RequestBody != nil {
- logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
+ reqLen = int64(len(params.RequestBytes))
+ } else if params.RequestBody != nil {
+ logBody = ""
+ reqLen = -1
+ if params.RequestLength > 0 {
+ logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
+ reqLen = params.RequestLength
+ } else if params.RequestLength == 0 {
+ zerolog.Ctx(ctx).Warn().
+ Msg("RequestBody passed without specifying request length")
+ }
+ reqBody = params.RequestBody
if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok {
// Prevent HTTP from closing the request body, it might be needed for retries
reqBody = nopCloseSeeker{rsc}
}
} else if params.Method != http.MethodGet && params.Method != http.MethodHead {
params.RequestJSON = struct{}{}
- logBody = params.RequestJSON
+ logBody = json.RawMessage("{}")
reqBody = bytes.NewReader([]byte("{}"))
+ reqLen = 2
}
- reqID := atomic.AddInt32(&requestID, 1)
- logger := zerolog.Ctx(ctx)
- if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
- logger = params.Logger
- }
- ctx = logger.With().
- Int32("req_id", reqID).
- Logger().WithContext(ctx)
ctx = context.WithValue(ctx, LogBodyContextKey, logBody)
ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID))
req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody)
@@ -466,9 +504,7 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
if params.RequestJSON != nil {
req.Header.Set("Content-Type", "application/json")
}
- if params.RequestLength > 0 && params.RequestBody != nil {
- req.ContentLength = params.RequestLength
- }
+ req.ContentLength = reqLen
return req, nil
}
@@ -513,14 +549,31 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque
params.Handler = handleNormalResponse
}
}
- req.Header.Set("User-Agent", cli.UserAgent)
+ if cli.UserAgent != "" {
+ req.Header.Set("User-Agent", cli.UserAgent)
+ }
if len(cli.AccessToken) > 0 {
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
}
+ if params.ResponseSizeLimit == 0 {
+ params.ResponseSizeLimit = cli.ResponseSizeLimit
+ }
+ if params.ResponseSizeLimit == 0 {
+ params.ResponseSizeLimit = DefaultResponseSizeLimit
+ }
if params.Client == nil {
params.Client = cli.Client
}
- return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client)
+ return cli.executeCompiledRequest(
+ req,
+ params.MaxAttempts-1,
+ params.BackoffDuration,
+ params.ResponseJSON,
+ params.Handler,
+ params.DontReadResponse,
+ params.ResponseSizeLimit,
+ params.Client,
+ )
}
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
@@ -531,7 +584,17 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
return log
}
-func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
+func (cli *Client) doRetry(
+ req *http.Request,
+ cause error,
+ retries int,
+ backoff time.Duration,
+ responseJSON any,
+ handler ClientResponseHandler,
+ dontReadResponse bool,
+ sizeLimit int64,
+ client *http.Client,
+) ([]byte, *http.Response, error) {
log := zerolog.Ctx(req.Context())
if req.Body != nil {
var err error
@@ -553,21 +616,37 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff
}
}
log.Warn().Err(cause).
+ Str("method", req.Method).
+ Str("url", req.URL.String()).
Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Request failed, retrying")
select {
case <-time.After(backoff):
case <-req.Context().Done():
- return nil, nil, req.Context().Err()
+ if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) {
+ return nil, nil, req.Context().Err()
+ }
}
if cli.UpdateRequestOnRetry != nil {
req = cli.UpdateRequestOnRetry(req, cause)
}
- return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client)
+ return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client)
}
-func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) {
- contents, err := io.ReadAll(res.Body)
+func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) {
+ if res.ContentLength > limit {
+ return nil, HTTPError{
+ Request: req,
+ Response: res,
+
+ Message: "not reading response",
+ WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
+ }
+ }
+ contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1))
+ if err == nil && len(contents) > int(limit) {
+ err = ErrBodyReadReachedLimit
+ }
if err != nil {
return nil, HTTPError{
Request: req,
@@ -588,17 +667,20 @@ func closeTemp(log *zerolog.Logger, file *os.File) {
}
}
-func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
+func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
log := zerolog.Ctx(req.Context())
file, err := os.CreateTemp("", "mautrix-response-")
if err != nil {
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
- _, err = handleNormalResponse(req, res, responseJSON)
+ _, err = handleNormalResponse(req, res, responseJSON, limit)
return nil, err
}
defer closeTemp(log, file)
- if _, err = io.Copy(file, res.Body); err != nil {
+ var n int64
+ if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil {
return nil, fmt.Errorf("failed to copy response to file: %w", err)
+ } else if n > limit {
+ return nil, ErrBodyReadReachedLimit
} else if _, err = file.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err)
} else if err = json.NewDecoder(file).Decode(responseJSON); err != nil {
@@ -608,12 +690,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac
}
}
-func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
+func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
return nil, nil
}
-func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
- if contents, err := readResponseBody(req, res); err != nil {
+func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
+ if contents, err := readResponseBody(req, res, limit); err != nil {
return nil, err
} else if responseJSON == nil {
return contents, nil
@@ -631,8 +713,13 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in
}
}
+const ErrorResponseSizeLimit = 512 * 1024
+
+var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024
+
func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
- contents, err := readResponseBody(req, res)
+ defer res.Body.Close()
+ contents, err := readResponseBody(req, res, ErrorResponseSizeLimit)
if err != nil {
return contents, err
}
@@ -651,17 +738,31 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
}
}
-func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
+func (cli *Client) executeCompiledRequest(
+ req *http.Request,
+ retries int,
+ backoff time.Duration,
+ responseJSON any,
+ handler ClientResponseHandler,
+ dontReadResponse bool,
+ sizeLimit int64,
+ client *http.Client,
+) ([]byte, *http.Response, error) {
cli.RequestStart(req)
startTime := time.Now()
res, err := client.Do(req)
- duration := time.Now().Sub(startTime)
+ duration := time.Since(startTime)
if res != nil && !dontReadResponse {
defer res.Body.Close()
}
if err != nil {
- if retries > 0 && !errors.Is(err, context.Canceled) {
- return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client)
+ // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry
+ canRetry := !errors.Is(err, context.Canceled) ||
+ errors.Is(context.Cause(req.Context()), ErrContextCancelRetry)
+ if retries > 0 && canRetry {
+ return cli.doRetry(
+ req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
+ )
}
err = HTTPError{
Request: req,
@@ -676,7 +777,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
- return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client)
+ return cli.doRetry(
+ req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
+ )
}
var body []byte
@@ -684,7 +787,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
} else {
- body, err = handler(req, res, responseJSON)
+ body, err = handler(req, res, responseJSON, sizeLimit)
cli.LogRequestDone(req, res, nil, err, len(body), duration)
}
return body, res, err
@@ -744,7 +847,7 @@ func (req *ReqSync) BuildQuery() map[string]string {
query["full_state"] = "true"
}
if req.UseStateAfter {
- query["org.matrix.msc4222.use_state_after"] = "true"
+ query["use_state_after"] = "true"
}
if req.BeeperStreaming {
query["com.beeper.streaming"] = "true"
@@ -768,7 +871,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp
}
start := time.Now()
_, err = cli.MakeFullRequest(ctx, fullReq)
- duration := time.Now().Sub(start)
+ duration := time.Since(start)
timeout := time.Duration(req.Timeout) * time.Millisecond
buffer := 10 * time.Second
if req.Since == "" {
@@ -815,7 +918,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp
return
}
-func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
+func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
var bodyBytes []byte
bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
@@ -839,7 +942,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (
// Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
//
// Registers with kind=user. For kind=guest, see RegisterGuest.
-func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
+func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
u := cli.BuildClientURL("v3", "register")
return cli.register(ctx, u, req)
}
@@ -848,7 +951,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegiste
// with kind=guest.
//
// For kind=user, see Register.
-func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
+func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
query := map[string]string{
"kind": "guest",
}
@@ -871,8 +974,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe
// panic(err)
// }
// token := res.AccessToken
-func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) {
- res, uia, err := cli.Register(ctx, req)
+func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) {
+ _, uia, err := cli.Register(ctx, req)
if err != nil && uia == nil {
return nil, err
} else if uia == nil {
@@ -881,7 +984,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe
return nil, errors.New("server does not support m.login.dummy")
}
req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session}
- res, _, err = cli.Register(ctx, req)
+ res, _, err := cli.Register(ctx, req)
if err != nil {
return nil, err
}
@@ -1045,8 +1148,19 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs
return
}
+func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) {
+ urlPath := cli.BuildClientURL("v3", "user_directory", "search")
+ _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{
+ SearchTerm: query,
+ Limit: limit,
+ }, &resp)
+ return
+}
+
func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) {
- if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) {
+ supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms)
+ supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms)
+ if cli.SpecVersions != nil && !supportsUnstable && !supportsStable {
err = fmt.Errorf("server does not support fetching mutual rooms")
return
}
@@ -1056,7 +1170,10 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex
if len(extras) > 0 {
query["from"] = extras[0].From
}
- urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
+ urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query)
+ if !supportsStable && supportsUnstable {
+ urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
+ }
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
@@ -1078,8 +1195,7 @@ func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via
// GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname
func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) {
- urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname")
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
+ err = cli.GetProfileField(ctx, mxid, "displayname", &resp)
return
}
@@ -1090,41 +1206,47 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay
// SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname
func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) {
- urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname")
- s := struct {
- DisplayName string `json:"displayname"`
- }{displayName}
- _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil)
- return
+ return cli.SetProfileField(ctx, "displayname", displayName)
}
-// UnstableSetProfileField sets an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
-func (cli *Client) UnstableSetProfileField(ctx context.Context, key string, value any) (err error) {
- urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
+// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
+func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) {
+ urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
+ if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
+ urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
+ }
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{
key: value,
}, nil)
return
}
-// UnstableDeleteProfileField deletes an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
-func (cli *Client) UnstableDeleteProfileField(ctx context.Context, key string) (err error) {
- urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
+// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
+func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) {
+ urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
+ if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
+ urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
+ }
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
return
}
+// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname
+func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) {
+ urlPath := cli.BuildClientURL("v3", "profile", userID, key)
+ if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
+ urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
+ }
+ _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into)
+ return
+}
+
// GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url
func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) {
- urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url")
s := struct {
AvatarURL id.ContentURI `json:"avatar_url"`
}{}
-
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s)
- if err != nil {
- return
- }
+ err = cli.GetProfileField(ctx, mxid, "avatar_url", &s)
url = s.AvatarURL
return
}
@@ -1216,6 +1338,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
+ if req.UnstableStickyDuration > 0 {
+ queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
+ }
if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted {
var isEncrypted bool
@@ -1239,9 +1364,51 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
return
}
-// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
+// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint.
+// contentJSON should be a value that can be encoded as JSON using json.Marshal.
+func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
+ var req ReqSendEvent
+ if len(extra) > 0 {
+ req = extra[0]
+ }
+
+ var txnID string
+ if len(req.TransactionID) > 0 {
+ txnID = req.TransactionID
+ } else {
+ txnID = cli.TxnID()
+ }
+
+ queryParams := map[string]string{}
+ if req.Timestamp > 0 {
+ queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
+ }
+
+ if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted {
+ var isEncrypted bool
+ isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID)
+ if err != nil {
+ err = fmt.Errorf("failed to check if room is encrypted: %w", err)
+ return
+ }
+ if isEncrypted {
+ if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil {
+ err = fmt.Errorf("failed to encrypt event: %w", err)
+ return
+ }
+ eventType = event.EventEncrypted
+ }
+ }
+
+ urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID}
+ urlPath := cli.BuildURLWithQuery(urlData, queryParams)
+ _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
+ return
+}
+
+// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
-func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
+func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
var req ReqSendEvent
if len(extra) > 0 {
req = extra[0]
@@ -1251,9 +1418,18 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
if req.MeowEventID != "" {
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
}
+ if req.TransactionID != "" {
+ queryParams["fi.mau.transaction_id"] = req.TransactionID
+ }
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
+ if req.UnstableStickyDuration > 0 {
+ queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
+ }
+ if req.Timestamp > 0 {
+ queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
+ }
urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
@@ -1266,14 +1442,38 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
// SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
+//
+// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead.
func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
- urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
- "ts": strconv.FormatInt(ts, 10),
+ resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{
+ Timestamp: ts,
})
- _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
- if err == nil && cli.StateStore != nil {
- cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
+ return
+}
+
+func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) {
+ query := map[string]string{}
+ if req.DelayID != "" {
+ query["delay_id"] = string(req.DelayID)
}
+ if req.Status != "" {
+ query["status"] = string(req.Status)
+ }
+ if req.NextBatch != "" {
+ query["next_batch"] = req.NextBatch
+ }
+
+ urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query)
+ _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp)
+
+ // Migration: merge old keys with new ones
+ if resp != nil {
+ resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...)
+ resp.DelayedEvents = nil
+ resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...)
+ resp.FinalisedEvents = nil
+ }
+
return
}
@@ -1364,6 +1564,10 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re
Msg("Failed to update creator membership in state store after creating room")
}
for _, evt := range req.InitialState {
+ evt.RoomID = resp.RoomID
+ if evt.StateKey == nil {
+ evt.StateKey = ptr.Ptr("")
+ }
UpdateStateStore(ctx, cli.StateStore, evt)
}
inviteMembership := event.MembershipInvite
@@ -1378,9 +1582,6 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re
Msg("Failed to update membership in state store after creating room")
}
}
- for _, evt := range req.InitialState {
- cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
- }
}
return
}
@@ -1551,22 +1752,34 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy
"format": "event",
})
_, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &evt)
- if err == nil && cli.StateStore != nil {
- UpdateStateStore(ctx, cli.StateStore, evt)
- }
if evt != nil {
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
+ if evt.RoomID == "" {
+ evt.RoomID = roomID
+ }
+ }
+ if err == nil && cli.StateStore != nil {
+ UpdateStateStore(ctx, cli.StateStore, evt)
}
return
}
// parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map.
-func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
+func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
+ if res.ContentLength > limit {
+ return nil, HTTPError{
+ Request: req,
+ Response: res,
+
+ Message: "not reading response",
+ WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
+ }
+ }
response := make(RoomStateMap)
responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event)
*responsePtr = response
- dec := json.NewDecoder(res.Body)
+ dec := json.NewDecoder(io.LimitReader(res.Body, limit))
arrayStart, err := dec.Token()
if err != nil {
@@ -1600,6 +1813,8 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter
return nil, nil
}
+type RoomStateMap = map[event.Type]map[string]*event.Event
+
// State gets all state in a room.
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate
func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) {
@@ -1609,12 +1824,21 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
ResponseJSON: &stateMap,
Handler: parseRoomStateArray,
})
+ if stateMap != nil {
+ pls, ok := stateMap[event.StatePowerLevels][""]
+ if ok {
+ pls.Content.AsPowerLevels().CreateEvent = stateMap[event.StateCreate][""]
+ }
+ }
if err == nil && cli.StateStore != nil {
for evtType, evts := range stateMap {
if evtType == event.StateMember {
continue
}
for _, evt := range evts {
+ if evt.RoomID == "" {
+ evt.RoomID = roomID
+ }
UpdateStateStore(ctx, cli.StateStore, evt)
}
}
@@ -1673,6 +1897,9 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa
}
func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) {
+ if mxcURL.IsEmpty() {
+ return nil, fmt.Errorf("empty mxc uri provided to Download")
+ }
_, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
Method: http.MethodGet,
URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID),
@@ -1681,6 +1908,41 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re
return resp, err
}
+type DownloadThumbnailExtra struct {
+ Method string
+ Animated bool
+}
+
+func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) {
+ if mxcURL.IsEmpty() {
+ return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail")
+ }
+ if len(extras) > 1 {
+ panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras)))
+ }
+ var extra DownloadThumbnailExtra
+ if len(extras) == 1 {
+ extra = extras[0]
+ }
+ path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID}
+ query := map[string]string{
+ "height": strconv.Itoa(height),
+ "width": strconv.Itoa(width),
+ }
+ if extra.Method != "" {
+ query["method"] = extra.Method
+ }
+ if extra.Animated {
+ query["animated"] = "true"
+ }
+ _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
+ Method: http.MethodGet,
+ URL: cli.BuildURLWithQuery(path, query),
+ DontReadResponse: true,
+ })
+ return resp, err
+}
+
func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) {
resp, err := cli.Download(ctx, mxcURL)
if err != nil {
@@ -1727,10 +1989,15 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr
}
req.MXC = resp.ContentURI
req.UnstableUploadURL = resp.UnstableUploadURL
+ if req.AsyncContext == nil {
+ req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background())
+ }
go func() {
- _, err = cli.UploadMedia(ctx, req)
+ _, err = cli.UploadMedia(req.AsyncContext, req)
if err != nil {
- cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed")
+ zerolog.Ctx(req.AsyncContext).Err(err).
+ Stringer("mxc", req.MXC).
+ Msg("Async upload of media failed")
}
}()
return resp, nil
@@ -1766,6 +2033,7 @@ type ReqUploadMedia struct {
ContentType string
FileName string
+ AsyncContext context.Context
DoneCallback func()
// MXC specifies an existing MXC URI which doesn't have content yet to upload into.
@@ -1778,14 +2046,19 @@ type ReqUploadMedia struct {
}
func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) {
- cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL")
+ cli.Log.Debug().
+ Str("url", url).
+ Int64("content_length", contentLength).
+ Msg("Uploading media to external URL")
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
req.Header.Set("Content-Type", contentType)
- req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)")
+ if cli.UserAgent != "" {
+ req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)")
+ }
if cli.ExternalClient != nil {
return cli.ExternalClient.Do(req)
@@ -1825,8 +2098,16 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*
Msg("Error uploading media to external URL, not retrying")
return nil, err
}
- cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err).
+ backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries)
+ cli.Log.Warn().Err(err).
+ Str("url", data.UnstableUploadURL).
+ Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Error uploading media to external URL, retrying")
+ select {
+ case <-time.After(backoff):
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
retries--
_, err = readerSeeker.Seek(0, io.SeekStart)
if err != nil {
@@ -2406,15 +2687,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req
return err
}
-func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error {
+func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error {
urlPath := cli.BuildClientURL("v3", "devices", deviceID)
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err
}
-func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error {
+func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error {
urlPath := cli.BuildClientURL("v3", "delete_devices")
- _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
+ _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil)
return err
}
@@ -2423,7 +2704,7 @@ type UIACallback = func(*RespUserInteractive) interface{}
// UploadCrossSigningKeys uploads the given cross-signing keys to the server.
// Because the endpoint requires user-interactive authentication a callback must be provided that,
// given the UI auth parameters, produces the required result (or nil to end the flow).
-func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error {
+func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error {
content, err := cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"),
@@ -2505,24 +2786,61 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri
return err
}
-// BatchSend sends a batch of historical events into a room. This is only available for appservices.
+// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin.
//
-// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead.
-func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) {
- path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"}
- query := map[string]string{
- "prev_event_id": req.PrevEventID.String(),
+// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
+func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) {
+ urlPath := cli.BuildClientURL("v3", "admin", "whois", userID)
+ _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
+ return
+}
+
+func (cli *Client) makeMSC4323URL(action string, target id.UserID) string {
+ if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) {
+ return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target)
+ } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) {
+ return cli.BuildClientURL("v1", "admin", action, target)
}
- if req.BeeperNewMessages {
- query["com.beeper.new_messages"] = "true"
+ return ""
+}
+
+// GetSuspendedStatus uses MSC4323 to check if a user is suspended.
+func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) {
+ urlPath := cli.makeMSC4323URL("suspend", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
}
- if req.BeeperMarkReadBy != "" {
- query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String()
+ _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
+ return
+}
+
+// GetLockStatus uses MSC4323 to check if a user is locked.
+func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) {
+ urlPath := cli.makeMSC4323URL("lock", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
}
- if len(req.BatchID) > 0 {
- query["batch_id"] = req.BatchID.String()
+ _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
+ return
+}
+
+// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended.
+func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) {
+ urlPath := cli.makeMSC4323URL("suspend", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
}
- _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp)
+ _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res)
+ return
+}
+
+// SetLockStatus uses MSC4323 to set whether a user account is locked.
+func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) {
+ urlPath := cli.makeMSC4323URL("lock", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ }
+ _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res)
return
}
diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go
new file mode 100644
index 00000000..c2846427
--- /dev/null
+++ b/client_ephemeral_test.go
@@ -0,0 +1,158 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mautrix_test
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) {
+ roomID := id.RoomID("!room:example.com")
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+ txnID := "txn-123"
+
+ var gotPath string
+ var gotQueryTS string
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ gotQueryTS = r.URL.Query().Get("ts")
+ assert.Equal(t, http.MethodPut, r.Method)
+ _, _ = w.Write([]byte(`{"event_id":"$evt"}`))
+ }))
+ defer ts.Close()
+
+ cli, err := mautrix.NewClient(ts.URL, "", "")
+ require.NoError(t, err)
+
+ _, err = cli.BeeperSendEphemeralEvent(
+ context.Background(),
+ roomID,
+ evtType,
+ map[string]any{"foo": "bar"},
+ mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234},
+ )
+ require.NoError(t, err)
+
+ assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/"))
+ assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID))
+ assert.Equal(t, "1234", gotQueryTS)
+}
+
+func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`))
+ }))
+ defer ts.Close()
+
+ cli, err := mautrix.NewClient(ts.URL, "", "")
+ require.NoError(t, err)
+
+ _, err = cli.BeeperSendEphemeralEvent(
+ context.Background(),
+ id.RoomID("!room:example.com"),
+ event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType},
+ map[string]any{"foo": "bar"},
+ )
+ require.Error(t, err)
+ assert.True(t, errors.Is(err, mautrix.MUnrecognized))
+}
+
+func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) {
+ roomID := id.RoomID("!room:example.com")
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+ txnID := "txn-encrypted"
+
+ stateStore := mautrix.NewMemoryStateStore()
+ err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{
+ Algorithm: id.AlgorithmMegolmV1,
+ })
+ require.NoError(t, err)
+
+ fakeCrypto := &fakeCryptoHelper{
+ encryptedContent: &event.EncryptedEventContent{
+ Algorithm: id.AlgorithmMegolmV1,
+ MegolmCiphertext: []byte("ciphertext"),
+ },
+ }
+
+ var gotPath string
+ var gotBody map[string]any
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ assert.Equal(t, http.MethodPut, r.Method)
+ err := json.NewDecoder(r.Body).Decode(&gotBody)
+ require.NoError(t, err)
+ _, _ = w.Write([]byte(`{"event_id":"$evt"}`))
+ }))
+ defer ts.Close()
+
+ cli, err := mautrix.NewClient(ts.URL, "", "")
+ require.NoError(t, err)
+ cli.StateStore = stateStore
+ cli.Crypto = fakeCrypto
+
+ _, err = cli.BeeperSendEphemeralEvent(
+ context.Background(),
+ roomID,
+ evtType,
+ map[string]any{"foo": "bar"},
+ mautrix.ReqSendEvent{TransactionID: txnID},
+ )
+ require.NoError(t, err)
+
+ assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID))
+ assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"])
+ assert.Equal(t, 1, fakeCrypto.encryptCalls)
+ assert.Equal(t, roomID, fakeCrypto.lastRoomID)
+ assert.Equal(t, evtType, fakeCrypto.lastEventType)
+}
+
+type fakeCryptoHelper struct {
+ encryptCalls int
+ lastRoomID id.RoomID
+ lastEventType event.Type
+ lastEncryptInput any
+ encryptedContent *event.EncryptedEventContent
+}
+
+func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) {
+ f.encryptCalls++
+ f.lastRoomID = roomID
+ f.lastEventType = eventType
+ f.lastEncryptInput = content
+ return f.encryptedContent, nil
+}
+
+func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) {
+ return nil, nil
+}
+
+func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool {
+ return false
+}
+
+func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) {
+}
+
+func (f *fakeCryptoHelper) Init(context.Context) error {
+ return nil
+}
diff --git a/commands/container.go b/commands/container.go
index bc685b7b..9b909b75 100644
--- a/commands/container.go
+++ b/commands/container.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,14 +8,20 @@ package commands
import (
"fmt"
+ "slices"
"strings"
"sync"
+
+ "go.mau.fi/util/exmaps"
+
+ "maunium.net/go/mautrix/event/cmdschema"
)
type CommandContainer[MetaType any] struct {
commands map[string]*Handler[MetaType]
aliases map[string]string
lock sync.RWMutex
+ parent *Handler[MetaType]
}
func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
@@ -25,6 +31,29 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
}
}
+func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent {
+ data := make(exmaps.Set[*Handler[MetaType]])
+ cont.collectHandlers(data)
+ specs := make([]*cmdschema.EventContent, 0, data.Size())
+ for handler := range data.Iter() {
+ if handler.Parameters != nil {
+ specs = append(specs, handler.Spec())
+ }
+ }
+ return specs
+}
+
+func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) {
+ cont.lock.RLock()
+ defer cont.lock.RUnlock()
+ for _, handler := range cont.commands {
+ into.Add(handler)
+ if handler.subcommandContainer != nil {
+ handler.subcommandContainer.collectHandlers(into)
+ }
+ }
+}
+
// Register registers the given command handlers.
func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) {
if cont == nil {
@@ -32,7 +61,10 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType])
}
cont.lock.Lock()
defer cont.lock.Unlock()
- for _, handler := range handlers {
+ for i, handler := range handlers {
+ if handler == nil {
+ panic(fmt.Errorf("handler #%d is nil", i+1))
+ }
cont.registerOne(handler)
}
}
@@ -45,6 +77,10 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType])
} else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists {
panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget))
}
+ if !slices.Contains(handler.parents, cont.parent) {
+ handler.parents = append(handler.parents, cont.parent)
+ handler.nestedNameCache = nil
+ }
cont.commands[handler.Name] = handler
for _, alias := range handler.Aliases {
if strings.ToLower(alias) != alias {
diff --git a/commands/event.go b/commands/event.go
index 77a3c0d2..76d6c9f0 100644
--- a/commands/event.go
+++ b/commands/event.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,7 @@ package commands
import (
"context"
+ "encoding/json"
"fmt"
"strings"
@@ -35,6 +36,8 @@ type Event[MetaType any] struct {
// RawArgs is the same as args, but without the splitting by whitespace.
RawArgs string
+ StructuredArgs json.RawMessage
+
Ctx context.Context
Log *zerolog.Logger
Proc *Processor[MetaType]
@@ -61,7 +64,7 @@ var IDHTMLParser = &format.HTMLParser{
}
// ParseEvent parses a message into a command event struct.
-func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] {
+func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" {
return nil
@@ -70,12 +73,34 @@ func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[Meta
if content.Format == event.FormatHTML {
text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx))
}
+ if content.MSC4391BotCommand != nil {
+ if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 {
+ return nil
+ }
+ wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand)
+ wrapped.RawInput = text
+ return wrapped
+ }
if len(text) == 0 {
return nil
}
return RawTextToEvent[MetaType](ctx, evt, text)
}
+func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] {
+ commandParts := strings.Split(content.Command, " ")
+ return &Event[MetaType]{
+ Event: evt,
+ // Fake a command and args to let the subcommand finder in Process work.
+ Command: commandParts[0],
+ Args: commandParts[1:],
+ Ctx: ctx,
+ Log: zerolog.Ctx(ctx),
+
+ StructuredArgs: content.Arguments,
+ }
+}
+
func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] {
parts := strings.Fields(text)
if len(parts) == 0 {
@@ -188,3 +213,25 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) {
evt.RawArgs = arg + " " + evt.RawArgs
evt.Args = append([]string{arg}, evt.Args...)
}
+
+func (evt *Event[MetaType]) ParseArgs(into any) error {
+ return json.Unmarshal(evt.StructuredArgs, into)
+}
+
+func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) {
+ err = evt.ParseArgs(&into)
+ return
+}
+
+func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) {
+ return func(evt *Event[MetaType]) {
+ parsed, err := ParseArgs[T, MetaType](evt)
+ if err != nil {
+ evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct")
+ // TODO better error, usage info? deduplicate with Process
+ evt.Reply("Failed to parse arguments: %v", err)
+ return
+ }
+ fn(evt, parsed)
+ }
+}
diff --git a/commands/handler.go b/commands/handler.go
index b01d594f..56f27f06 100644
--- a/commands/handler.go
+++ b/commands/handler.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,9 @@ package commands
import (
"strings"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/event/cmdschema"
)
type Handler[MetaType any] struct {
@@ -25,12 +28,63 @@ type Handler[MetaType any] struct {
// Event.ShiftArg will likely be useful for implementing such parameters.
PreFunc func(ce *Event[MetaType])
+ // Description is a short description of the command.
+ Description *event.ExtensibleTextContainer
+ // Parameters is a description of structured command parameters.
+ // If set, the StructuredArgs field of Event will be populated.
+ Parameters []*cmdschema.Parameter
+ TailParam string
+
+ parents []*Handler[MetaType]
+ nestedNameCache []string
subcommandContainer *CommandContainer[MetaType]
}
+func (h *Handler[MetaType]) NestedNames() []string {
+ if h.nestedNameCache != nil {
+ return h.nestedNameCache
+ }
+ nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents))
+ for _, parent := range h.parents {
+ if parent == nil {
+ nestedNames = append(nestedNames, h.Name)
+ nestedNames = append(nestedNames, h.Aliases...)
+ } else {
+ for _, parentName := range parent.NestedNames() {
+ nestedNames = append(nestedNames, parentName+" "+h.Name)
+ for _, alias := range h.Aliases {
+ nestedNames = append(nestedNames, parentName+" "+alias)
+ }
+ }
+ }
+ }
+ h.nestedNameCache = nestedNames
+ return nestedNames
+}
+
+func (h *Handler[MetaType]) Spec() *cmdschema.EventContent {
+ names := h.NestedNames()
+ return &cmdschema.EventContent{
+ Command: names[0],
+ Aliases: names[1:],
+ Parameters: h.Parameters,
+ Description: h.Description,
+ TailParam: h.TailParam,
+ }
+}
+
+func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) {
+ if h.Parameters == nil {
+ h.Parameters = other.Parameters
+ h.TailParam = other.TailParam
+ }
+ h.Func = other.Func
+}
+
func (h *Handler[MetaType]) initSubcommandContainer() {
if len(h.Subcommands) > 0 {
h.subcommandContainer = NewCommandContainer[MetaType]()
+ h.subcommandContainer.parent = h
h.subcommandContainer.Register(h.Subcommands...)
} else {
h.subcommandContainer = nil
diff --git a/commands/processor.go b/commands/processor.go
index 9341329b..80f6745d 100644
--- a/commands/processor.go
+++ b/commands/processor.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
case event.EventReaction:
parsed = proc.ParseReaction(ctx, evt)
case event.EventMessage:
- parsed = ParseEvent[MetaType](ctx, evt)
+ parsed = proc.ParseEvent(ctx, evt)
}
- if parsed == nil || !proc.PreValidator.Validate(parsed) {
+ if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) {
return
}
parsed.Proc = proc
@@ -107,6 +107,12 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
break
}
}
+ if parsed.StructuredArgs != nil && len(parsed.Args) > 0 {
+ // TODO allow unknown command handlers to be called?
+ // The client sent MSC4391 data, but the target command wasn't found
+ log.Debug().Msg("Didn't find handler for MSC4391 command")
+ return
+ }
logWith := log.With().
Str("command", parsed.Command).
@@ -116,11 +122,31 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
}
if proc.LogArgs {
logWith = logWith.Strs("args", parsed.Args)
+ if parsed.StructuredArgs != nil {
+ logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs)
+ }
}
log = logWith.Logger()
parsed.Ctx = log.WithContext(ctx)
parsed.Log = &log
+ if handler.Parameters != nil && parsed.StructuredArgs == nil {
+ // The handler wants structured parameters, but the client didn't send MSC4391 data
+ var err error
+ parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs)
+ if err != nil {
+ log.Debug().Err(err).Msg("Failed to parse structured arguments")
+ // TODO better error, usage info? deduplicate with WithParsedArgs
+ parsed.Reply("Failed to parse arguments: %v", err)
+ return
+ }
+ if proc.LogArgs {
+ log.UpdateContext(func(c zerolog.Context) zerolog.Context {
+ return c.RawJSON("structured_args", parsed.StructuredArgs)
+ })
+ }
+ }
+
log.Debug().Msg("Processing command")
handler.Func(parsed)
}
diff --git a/commands/reactions.go b/commands/reactions.go
index 0df372e5..0d316219 100644
--- a/commands/reactions.go
+++ b/commands/reactions.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,7 @@ package commands
import (
"context"
+ "encoding/json"
"strings"
"github.com/rs/zerolog"
@@ -19,6 +20,11 @@ import (
const ReactionCommandsKey = "fi.mau.reaction_commands"
const ReactionMultiUseKey = "fi.mau.reaction_multi_use"
+type ReactionCommandData struct {
+ Command string `json:"command"`
+ Args any `json:"args,omitempty"`
+}
+
func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.ReactionEventContent)
if !ok {
@@ -67,21 +73,33 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E
Msg("Reaction command not found in target event")
return nil
}
- cmdString, ok := rawCmd.(string)
- if !ok {
+ var wrappedEvt *Event[MetaType]
+ switch typedCmd := rawCmd.(type) {
+ case string:
+ wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd)
+ case map[string]any:
+ var input event.MSC4391BotCommandInput
+ if marshaled, err := json.Marshal(typedCmd); err != nil {
+
+ } else if err = json.Unmarshal(marshaled, &input); err != nil {
+
+ } else {
+ wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input)
+ }
+ }
+ if wrappedEvt == nil {
zerolog.Ctx(ctx).Debug().
Stringer("target_event_id", evtID).
Str("reaction_key", content.RelatesTo.Key).
Msg("Reaction command data is invalid")
return nil
}
- wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString)
wrappedEvt.Proc = proc
wrappedEvt.Redact()
if !isMultiUse {
DeleteAllReactions(ctx, proc.Client, evt)
}
- if cmdString == "" {
+ if wrappedEvt.Command == "" {
return nil
}
return wrappedEvt
diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go
index bb03f706..d6611dc9 100644
--- a/crypto/aescbc/aes_cbc_test.go
+++ b/crypto/aescbc/aes_cbc_test.go
@@ -7,11 +7,13 @@
package aescbc_test
import (
- "bytes"
"crypto/aes"
"crypto/rand"
"testing"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
"maunium.net/go/mautrix/crypto/aescbc"
)
@@ -22,32 +24,23 @@ func TestAESCBC(t *testing.T) {
// The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256)
key := make([]byte, 32)
_, err = rand.Read(key)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
iv := make([]byte, aes.BlockSize)
_, err = rand.Read(iv)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
plaintext = []byte("secret message for testing")
//increase to next block size
for len(plaintext)%8 != 0 {
plaintext = append(plaintext, []byte("-")...)
}
- if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil {
- t.Fatal(err)
- }
+ ciphertext, err = aescbc.Encrypt(key, iv, plaintext)
+ require.NoError(t, err)
resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext)
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err)
- if string(resultPlainText) != string(plaintext) {
- t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext)
- }
+ assert.Equal(t, string(resultPlainText), string(plaintext))
}
func TestAESCBCCase1(t *testing.T) {
@@ -61,18 +54,10 @@ func TestAESCBCCase1(t *testing.T) {
key := make([]byte, 32)
iv := make([]byte, aes.BlockSize)
encrypted, err := aescbc.Encrypt(key, iv, input)
- if err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(expected, encrypted) {
- t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected)
- }
+ require.NoError(t, err)
+ assert.Equal(t, expected, encrypted, "encrypted output does not match expected")
decrypted, err := aescbc.Decrypt(key, iv, encrypted)
- if err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(input, decrypted) {
- t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input)
- }
+ require.NoError(t, err)
+ assert.Equal(t, input, decrypted, "decrypted output does not match input")
}
diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go
index cfa1c3e5..727aacbf 100644
--- a/crypto/attachment/attachments.go
+++ b/crypto/attachment/attachments.go
@@ -9,6 +9,7 @@ package attachment
import (
"crypto/aes"
"crypto/cipher"
+ "crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
@@ -20,13 +21,24 @@ import (
)
var (
- HashMismatch = errors.New("mismatching SHA-256 digest")
- UnsupportedVersion = errors.New("unsupported Matrix file encryption version")
- UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
- InvalidKey = errors.New("failed to decode key")
- InvalidInitVector = errors.New("failed to decode initialization vector")
- InvalidHash = errors.New("failed to decode SHA-256 hash")
- ReaderClosed = errors.New("encrypting reader was already closed")
+ ErrHashMismatch = errors.New("mismatching SHA-256 digest")
+ ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version")
+ ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
+ ErrInvalidKey = errors.New("failed to decode key")
+ ErrInvalidInitVector = errors.New("failed to decode initialization vector")
+ ErrInvalidHash = errors.New("failed to decode SHA-256 hash")
+ ErrReaderClosed = errors.New("encrypting reader was already closed")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ HashMismatch = ErrHashMismatch
+ UnsupportedVersion = ErrUnsupportedVersion
+ UnsupportedAlgorithm = ErrUnsupportedAlgorithm
+ InvalidKey = ErrInvalidKey
+ InvalidInitVector = ErrInvalidInitVector
+ InvalidHash = ErrInvalidHash
+ ReaderClosed = ErrReaderClosed
)
var (
@@ -84,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error {
if ef.decoded != nil {
return nil
} else if len(ef.Key.Key) != keyBase64Length {
- return InvalidKey
+ return ErrInvalidKey
} else if len(ef.InitVector) != ivBase64Length {
- return InvalidInitVector
+ return ErrInvalidInitVector
} else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length {
- return InvalidHash
+ return ErrInvalidHash
}
ef.decoded = &decodedKeys{}
_, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key))
if err != nil {
- return InvalidKey
+ return ErrInvalidKey
}
_, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector))
if err != nil {
- return InvalidInitVector
+ return ErrInvalidInitVector
}
if includeHash {
_, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256))
if err != nil {
- return InvalidHash
+ return ErrInvalidHash
}
}
return nil
@@ -178,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil)
func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
if r.closed {
- return 0, ReaderClosed
+ return 0, ErrReaderClosed
}
if offset != 0 || whence != io.SeekStart {
return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported")
@@ -199,15 +211,20 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
func (r *encryptingReader) Read(dst []byte) (n int, err error) {
if r.closed {
- return 0, ReaderClosed
+ return 0, ErrReaderClosed
} else if r.isDecrypting && r.file.decoded == nil {
if err = r.file.PrepareForDecryption(); err != nil {
return
}
}
n, err = r.source.Read(dst)
+ if r.isDecrypting {
+ r.hash.Write(dst[:n])
+ }
r.stream.XORKeyStream(dst[:n], dst[:n])
- r.hash.Write(dst[:n])
+ if !r.isDecrypting {
+ r.hash.Write(dst[:n])
+ }
return
}
@@ -217,10 +234,8 @@ func (r *encryptingReader) Close() (err error) {
err = closer.Close()
}
if r.isDecrypting {
- var downloadedChecksum [utils.SHAHashLength]byte
- r.hash.Sum(downloadedChecksum[:])
- if downloadedChecksum != r.file.decoded.sha256 {
- return HashMismatch
+ if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) {
+ return ErrHashMismatch
}
} else {
r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil))
@@ -261,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) {
// DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function.
func (ef *EncryptedFile) PrepareForDecryption() error {
if ef.Version != "v2" {
- return UnsupportedVersion
+ return ErrUnsupportedVersion
} else if ef.Key.Algorithm != "A256CTR" {
- return UnsupportedAlgorithm
+ return ErrUnsupportedAlgorithm
} else if err := ef.decodeKeys(true); err != nil {
return err
}
@@ -274,12 +289,13 @@ func (ef *EncryptedFile) PrepareForDecryption() error {
func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
if err := ef.PrepareForDecryption(); err != nil {
return err
- } else if ef.decoded.sha256 != sha256.Sum256(data) {
- return HashMismatch
- } else {
- utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
- return nil
}
+ dataHash := sha256.Sum256(data)
+ if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) {
+ return ErrHashMismatch
+ }
+ utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
+ return nil
}
// DecryptStream wraps the given io.Reader in order to decrypt the data.
@@ -292,9 +308,10 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadSeekCloser {
block, _ := aes.NewCipher(ef.decoded.key[:])
return &encryptingReader{
- stream: cipher.NewCTR(block, ef.decoded.iv[:]),
- hash: sha256.New(),
- source: reader,
- file: ef,
+ isDecrypting: true,
+ stream: cipher.NewCTR(block, ef.decoded.iv[:]),
+ hash: sha256.New(),
+ source: reader,
+ file: ef,
}
}
diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go
index d7f1394a..9fe929ab 100644
--- a/crypto/attachment/attachments_test.go
+++ b/crypto/attachment/attachments_test.go
@@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) {
file := parseHelloWorld()
file.Version = "foo"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, UnsupportedVersion)
+ assert.ErrorIs(t, err, ErrUnsupportedVersion)
}
func TestUnsupportedAlgorithm(t *testing.T) {
file := parseHelloWorld()
file.Key.Algorithm = "bar"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, UnsupportedAlgorithm)
+ assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
}
func TestHashMismatch(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes))
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, HashMismatch)
+ assert.ErrorIs(t, err, ErrHashMismatch)
}
func TestTooLongHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, InvalidHash)
+ assert.ErrorIs(t, err, ErrInvalidHash)
}
func TestTooShortHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "5/Gy1JftyyQ"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, InvalidHash)
+ assert.ErrorIs(t, err, ErrInvalidHash)
}
diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go
index ec551dbe..25250178 100644
--- a/crypto/backup/encryptedsessiondata.go
+++ b/crypto/backup/encryptedsessiondata.go
@@ -68,6 +68,10 @@ func calculateCompatMAC(macKey []byte) []byte {
//
// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) {
+ return EncryptSessionDataWithPubkey(backupKey.PublicKey(), sessionData)
+}
+
+func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T) (*EncryptedSessionData[T], error) {
sessionJSON, err := json.Marshal(sessionData)
if err != nil {
return nil, err
@@ -78,7 +82,7 @@ func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*Encr
return nil, err
}
- sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey())
+ sharedSecret, err := ephemeralKey.ECDH(pubkey)
if err != nil {
return nil, err
}
diff --git a/crypto/canonicaljson/json_test.go b/crypto/canonicaljson/json_test.go
index d1a7f0a5..36476aa4 100644
--- a/crypto/canonicaljson/json_test.go
+++ b/crypto/canonicaljson/json_test.go
@@ -17,31 +17,43 @@ package canonicaljson
import (
"testing"
+
+ "github.com/stretchr/testify/assert"
)
-func testSortJSON(t *testing.T, input, want string) {
- got := SortJSON([]byte(input), nil)
-
- // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
- if string(CompactJSON(got, nil)) != want {
- t.Errorf("SortJSON(%q): want %q got %q", input, want, got)
- }
-}
-
func TestSortJSON(t *testing.T) {
- testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`)
- testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`,
- `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`)
- testSortJSON(t, `[true,false,null]`, `[true,false,null]`)
- testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`)
- testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`)
+ var tests = []struct {
+ input string
+ want string
+ }{
+ {"{}", "{}"},
+ {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`},
+ {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`},
+ {`[true,false,null]`, `[true,false,null]`},
+ {`[9007199254740991]`, `[9007199254740991]`},
+ {"\t\n[9007199254740991]", `[9007199254740991]`},
+ {`[true,false,null]`, `[true,false,null]`},
+ {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`},
+ {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`},
+ {`[true,false,null]`, `[true,false,null]`},
+ {`[9007199254740991]`, `[9007199254740991]`},
+ {"\t\n[9007199254740991]", `[9007199254740991]`},
+ {`[true,false,null]`, `[true,false,null]`},
+ }
+ for _, test := range tests {
+ t.Run(test.input, func(t *testing.T) {
+ got := SortJSON([]byte(test.input), nil)
+
+ // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
+ assert.EqualValues(t, test.want, string(CompactJSON(got, nil)))
+ })
+ }
}
func testCompactJSON(t *testing.T, input, want string) {
+ t.Helper()
got := string(CompactJSON([]byte(input), nil))
- if got != want {
- t.Errorf("CompactJSON(%q): want %q got %q", input, want, got)
- }
+ assert.EqualValues(t, want, got)
}
func TestCompactJSON(t *testing.T) {
@@ -74,18 +86,23 @@ func TestCompactJSON(t *testing.T) {
testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`)
}
-func testReadHex(t *testing.T, input string, want uint32) {
- got := readHexDigits([]byte(input))
- if want != got {
- t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got)
+func TestReadHex(t *testing.T) {
+ tests := []struct {
+ input string
+ want uint32
+ }{
+
+ {"0123", 0x0123},
+ {"4567", 0x4567},
+ {"89AB", 0x89AB},
+ {"CDEF", 0xCDEF},
+ {"89ab", 0x89AB},
+ {"cdef", 0xCDEF},
+ }
+ for _, test := range tests {
+ t.Run(test.input, func(t *testing.T) {
+ got := readHexDigits([]byte(test.input))
+ assert.Equal(t, test.want, got)
+ })
}
}
-
-func TestReadHex(t *testing.T) {
- testReadHex(t, "0123", 0x0123)
- testReadHex(t, "4567", 0x4567)
- testReadHex(t, "89AB", 0x89AB)
- testReadHex(t, "CDEF", 0xCDEF)
- testReadHex(t, "89ab", 0x89AB)
- testReadHex(t, "cdef", 0xCDEF)
-}
diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go
index 4094f695..5d9bf5b3 100644
--- a/crypto/cross_sign_key.go
+++ b/crypto/cross_sign_key.go
@@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross
}
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
- err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
+ err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{
Master: masterKey,
SelfSigning: selfKey,
UserSigning: userKey,
diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go
index 77efab5b..223fc7b5 100644
--- a/crypto/cross_sign_pubkey.go
+++ b/crypto/cross_sign_pubkey.go
@@ -20,6 +20,20 @@ type CrossSigningPublicKeysCache struct {
UserSigningKey id.Ed25519
}
+func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) {
+ pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx)
+ if pubkeys != nil {
+ hasKeys = true
+ isVerified, err = mach.CryptoStore.IsKeySignedBy(
+ ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey,
+ )
+ if err != nil {
+ err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
+ }
+ }
+ return
+}
+
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache {
if mach.crossSigningPubkeys != nil {
return mach.crossSigningPubkeys
@@ -49,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id
if len(dbKeys) > 0 {
masterKey, ok := dbKeys[id.XSUsageMaster]
if ok {
- selfSigning, _ := dbKeys[id.XSUsageSelfSigning]
- userSigning, _ := dbKeys[id.XSUsageUserSigning]
+ selfSigning := dbKeys[id.XSUsageSelfSigning]
+ userSigning := dbKeys[id.XSUsageUserSigning]
return &CrossSigningPublicKeysCache{
MasterKey: masterKey.Key,
SelfSigningKey: selfSigning.Key,
diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go
index 389a9fd2..fd42880d 100644
--- a/crypto/cross_sign_ssss.go
+++ b/crypto/cross_sign_ssss.go
@@ -8,6 +8,7 @@ package crypto
import (
"context"
+ "errors"
"fmt"
"maunium.net/go/mautrix"
@@ -71,6 +72,46 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex
}, passphrase)
}
+func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error {
+ keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get default SSSS key data: %w", err)
+ }
+ key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
+ if errors.Is(err, ssss.ErrUnverifiableKey) {
+ mach.machOrContextLog(ctx).Warn().
+ Str("key_id", keyID).
+ Msg("SSSS key is unverifiable, trying to use without verifying")
+ } else if err != nil {
+ return err
+ }
+ err = mach.FetchCrossSigningKeysFromSSSS(ctx, key)
+ if err != nil {
+ return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
+ }
+ err = mach.SignOwnDevice(ctx, mach.OwnIdentity())
+ if err != nil {
+ return fmt.Errorf("failed to sign own device: %w", err)
+ }
+ err = mach.SignOwnMasterKey(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to sign own master key: %w", err)
+ }
+ return nil
+}
+
+func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) {
+ recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
+ if err != nil {
+ err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err)
+ } else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil {
+ err = fmt.Errorf("failed to sign own device: %w", err)
+ } else if err = mach.SignOwnMasterKey(ctx); err != nil {
+ err = fmt.Errorf("failed to sign own master key: %w", err)
+ }
+ return
+}
+
// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys.
//
// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key
@@ -97,12 +138,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u
// Publish cross-signing keys
err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback)
if err != nil {
- return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err)
+ return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err)
}
err = mach.SSSS.SetDefaultKeyID(ctx, key.ID)
if err != nil {
- return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
+ return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
}
return key.RecoveryKey(), keysCache, nil
diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go
index b583bada..57406b11 100644
--- a/crypto/cross_sign_store.go
+++ b/crypto/cross_sign_store.go
@@ -20,36 +20,34 @@ import (
func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
log := mach.machOrContextLog(ctx)
for userID, userKeys := range crossSigningKeys {
- log := log.With().Str("user_id", userID.String()).Logger()
+ log := log.With().Stringer("user_id", userID).Logger()
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
}
- if currentKeys != nil {
- for curKeyUsage, curKey := range currentKeys {
- log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger()
- // got a new key with the same usage as an existing key
- for _, newKeyUsage := range userKeys.Usage {
- if newKeyUsage == curKeyUsage {
- if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
- // old key is not in the new key map, so we drop signatures made by it
- if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
- log.Error().Err(err).Msg("Error deleting old signatures made by user")
- } else {
- log.Debug().
- Int64("signature_count", count).
- Msg("Dropped signatures made by old key as it has been replaced")
- }
+ for curKeyUsage, curKey := range currentKeys {
+ log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
+ // got a new key with the same usage as an existing key
+ for _, newKeyUsage := range userKeys.Usage {
+ if newKeyUsage == curKeyUsage {
+ if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
+ // old key is not in the new key map, so we drop signatures made by it
+ if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
+ log.Error().Err(err).Msg("Error deleting old signatures made by user")
+ } else {
+ log.Debug().
+ Int64("signature_count", count).
+ Msg("Dropped signatures made by old key as it has been replaced")
}
- break
}
+ break
}
}
}
for _, key := range userKeys.Keys {
- log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
+ log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
for _, usage := range userKeys.Usage {
log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key")
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {
diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go
index 5e1ffd50..b70370a2 100644
--- a/crypto/cross_sign_test.go
+++ b/crypto/cross_sign_test.go
@@ -13,6 +13,8 @@ import (
"testing"
"github.com/rs/zerolog"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
@@ -24,17 +26,12 @@ var noopLogger = zerolog.Nop()
func getOlmMachine(t *testing.T) *OlmMachine {
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
- if err != nil {
- t.Fatalf("Error opening db: %v", err)
- }
+ require.NoError(t, err, "Error opening raw database")
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
- if err != nil {
- t.Fatalf("Error opening db: %v", err)
- }
+ require.NoError(t, err, "Error creating database wrapper")
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
- if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
- t.Fatalf("Error creating tables: %v", err)
- }
+ err = sqlStore.DB.Upgrade(context.TODO())
+ require.NoError(t, err, "Error upgrading database")
userID := id.UserID("@mautrix")
mk, _ := olm.NewPKSigning()
@@ -66,29 +63,25 @@ func TestTrustOwnDevice(t *testing.T) {
DeviceID: "device",
SigningKey: id.Ed25519("deviceKey"),
}
- if m.IsDeviceTrusted(context.TODO(), ownDevice) {
- t.Error("Own device trusted while it shouldn't be")
- }
+ assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be")
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(),
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey,
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2")
- if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted {
- t.Error("Own user not trusted while they should be")
- }
- if !m.IsDeviceTrusted(context.TODO(), ownDevice) {
- t.Error("Own device not trusted while it should be")
- }
+ trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID)
+ require.NoError(t, err, "Error checking if own user is trusted")
+ assert.True(t, trusted, "Own user not trusted while they should be")
+ assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be")
}
func TestTrustOtherUser(t *testing.T) {
m := getOlmMachine(t)
otherUser := id.UserID("@user")
- if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
- t.Error("Other user trusted while they shouldn't be")
- }
+ trusted, err := m.IsUserTrusted(context.TODO(), otherUser)
+ require.NoError(t, err, "Error checking if other user is trusted")
+ assert.False(t, trusted, "Other user trusted while they shouldn't be")
theirMasterKey, _ := olm.NewPKSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
@@ -100,16 +93,16 @@ func TestTrustOtherUser(t *testing.T) {
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig")
- if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
- t.Error("Other user trusted before their master key has been signed with our user-signing key")
- }
+ trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
+ require.NoError(t, err, "Error checking if other user is trusted")
+ assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
- if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
- t.Error("Other user not trusted while they should be")
- }
+ trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
+ require.NoError(t, err, "Error checking if other user is trusted")
+ assert.True(t, trusted, "Other user not trusted while they should be")
}
func TestTrustOtherDevice(t *testing.T) {
@@ -120,12 +113,11 @@ func TestTrustOtherDevice(t *testing.T) {
DeviceID: "theirDevice",
SigningKey: id.Ed25519("theirDeviceKey"),
}
- if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
- t.Error("Other user trusted while they shouldn't be")
- }
- if m.IsDeviceTrusted(context.TODO(), theirDevice) {
- t.Error("Other device trusted while it shouldn't be")
- }
+
+ trusted, err := m.IsUserTrusted(context.TODO(), otherUser)
+ require.NoError(t, err, "Error checking if other user is trusted")
+ assert.False(t, trusted, "Other user trusted while they shouldn't be")
+ assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be")
theirMasterKey, _ := olm.NewPKSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
@@ -137,21 +129,17 @@ func TestTrustOtherDevice(t *testing.T) {
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
- if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
- t.Error("Other user not trusted while they should be")
- }
+ trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
+ require.NoError(t, err, "Error checking if other user is trusted")
+ assert.True(t, trusted, "Other user not trusted while they should be")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(),
otherUser, theirMasterKey.PublicKey(), "sig3")
- if m.IsDeviceTrusted(context.TODO(), theirDevice) {
- t.Error("Other device trusted before it has been signed with user's SSK")
- }
+ assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey,
otherUser, theirSSK.PublicKey(), "sig4")
- if !m.IsDeviceTrusted(context.TODO(), theirDevice) {
- t.Error("Other device not trusted while it should be")
- }
+ assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK")
}
diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go
index 56f8b484..b62dc128 100644
--- a/crypto/cryptohelper/cryptohelper.go
+++ b/crypto/cryptohelper/cryptohelper.go
@@ -225,13 +225,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted)
}
- if helper.client.SetAppServiceDeviceID {
- err = helper.mach.ShareKeys(ctx, -1)
- if err != nil {
- return fmt.Errorf("failed to share keys: %w", err)
- }
- }
-
return nil
}
@@ -268,24 +261,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error
if !ok || len(device.Keys) == 0 {
if isShared {
return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server")
- } else {
- helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine")
- return nil
}
+ helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys")
+ err = helper.mach.ShareKeys(ctx, -1)
+ if err != nil {
+ return fmt.Errorf("failed to share keys: %w", err)
+ }
+ return nil
} else if !isShared {
return fmt.Errorf("olm account is not marked as shared, but there are keys on the server")
} else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed {
return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed)
- }
- if !isShared {
- helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?")
} else {
helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine")
+ return nil
}
- return nil
}
-var NoSessionFound = crypto.NoSessionFound
+var NoSessionFound = crypto.ErrNoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
@@ -304,24 +297,14 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even
ctx = log.WithContext(ctx)
decrypted, err := helper.Decrypt(ctx, evt)
- if errors.Is(err, NoSessionFound) {
- log.Debug().
- Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
- Msg("Couldn't find session, waiting for keys to arrive...")
- if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
- log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
- decrypted, err = helper.Decrypt(ctx, evt)
- } else {
- go helper.waitLongerForSession(ctx, log, evt)
- return
- }
- }
- if err != nil {
+ if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" {
+ go helper.waitForSession(ctx, evt)
+ } else if err != nil {
log.Warn().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
- return
+ } else {
+ helper.postDecrypt(ctx, decrypted)
}
- helper.postDecrypt(ctx, decrypted)
}
func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) {
@@ -362,10 +345,33 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
}
}
-func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) {
+func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) {
+ log := zerolog.Ctx(ctx)
+ content := evt.Content.AsEncrypted()
+
+ log.Debug().
+ Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
+ Msg("Couldn't find session, waiting for keys to arrive...")
+ if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
+ log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
+ decrypted, err := helper.Decrypt(ctx, evt)
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to decrypt event")
+ helper.DecryptErrorCallback(evt, err)
+ } else {
+ helper.postDecrypt(ctx, decrypted)
+ }
+ } else {
+ go helper.waitLongerForSession(ctx, evt)
+ }
+}
+
+func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) {
+ log := zerolog.Ctx(ctx)
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
+ //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@@ -413,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R
defer helper.lock.RUnlock()
encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content)
if err != nil {
- if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
+ if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) {
return
}
helper.log.Debug().
diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go
index 47279474..457d5a0c 100644
--- a/crypto/decryptmegolm.go
+++ b/crypto/decryptmegolm.go
@@ -24,13 +24,23 @@ import (
)
var (
- IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
- NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
- DuplicateMessageIndex = errors.New("duplicate megolm message index")
- WrongRoom = errors.New("encrypted megolm event is not intended for this room")
- DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
- SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
- RatchetError = errors.New("failed to ratchet session after use")
+ ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
+ ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
+ ErrDuplicateMessageIndex = errors.New("duplicate megolm message index")
+ ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room")
+ ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
+ ErrRatchetError = errors.New("failed to ratchet session after use")
+ ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType
+ NoSessionFound = ErrNoSessionFound
+ DuplicateMessageIndex = ErrDuplicateMessageIndex
+ WrongRoom = ErrWrongRoom
+ DeviceKeyMismatch = ErrDeviceKeyMismatch
+ RatchetError = ErrRatchetError
)
type megolmEvent struct {
@@ -45,13 +55,30 @@ var (
relatesToTopLevelPath = exgjson.Path("content", "m.relates_to")
)
+const sessionIDLength = 43
+
+func validateCiphertextCharacters(ciphertext []byte) bool {
+ for _, b := range ciphertext {
+ if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' {
+ return false
+ }
+ }
+ return true
+}
+
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
- return nil, IncorrectEncryptedContentType
+ return nil, ErrIncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmMegolmV1 {
- return nil, UnsupportedAlgorithm
+ return nil, ErrUnsupportedAlgorithm
+ } else if len(content.MegolmCiphertext) < 74 {
+ return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext))
+ } else if len(content.SessionID) != sessionIDLength {
+ return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID))
+ } else if !validateCiphertextCharacters(content.MegolmCiphertext) {
+ return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload)
}
log := mach.machOrContextLog(ctx).With().
Str("action", "decrypt megolm event").
@@ -97,7 +124,13 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
Msg("Couldn't resolve trust level of session: sent by unknown device")
trustLevel = id.TrustStateUnknownDevice
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
- return nil, DeviceKeyMismatch
+ log.Debug().
+ Stringer("session_sender_key", sess.SenderKey).
+ Stringer("device_sender_key", device.IdentityKey).
+ Stringer("session_signing_key", sess.SigningKey).
+ Stringer("device_signing_key", device.SigningKey).
+ Msg("Device keys don't match keys in session, marking as untrusted")
+ trustLevel = id.TrustStateDeviceKeyMismatch
} else {
trustLevel, err = mach.ResolveTrustContext(ctx, device)
if err != nil {
@@ -147,7 +180,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
if err != nil {
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
} else if megolmEvt.RoomID != encryptionRoomID {
- return nil, WrongRoom
+ return nil, ErrWrongRoom
}
if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState {
megolmEvt.Type.Class = event.StateEventType
@@ -180,6 +213,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
TrustSource: device,
ForwardedKeys: forwardedKeys,
WasEncrypted: true,
+ EventSource: evt.Mautrix.EventSource | event.SourceDecrypted,
ReceivedAt: evt.Mautrix.ReceivedAt,
},
}, nil
@@ -201,19 +235,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
if decodeErr != nil {
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
- return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
+ return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex)
}
firstKnown := sess.Internal.FirstKnownIndex()
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
- return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
} else if !ok {
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
- return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown)
}
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
- return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
}
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
@@ -224,13 +258,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
- return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
- } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
- return sess, nil, 0, SenderKeyMismatch
+ return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID)
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
- if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
+ if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
}
@@ -238,7 +270,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
- return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
+ return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex)
}
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
@@ -290,24 +322,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Info().Msg("Deleted fully used session")
}
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
}
diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go
index b737e4e1..aea5e6dc 100644
--- a/crypto/decryptolm.go
+++ b/crypto/decryptolm.go
@@ -17,21 +17,36 @@ import (
"time"
"github.com/rs/zerolog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/ptr"
+ "maunium.net/go/mautrix/crypto/goolm/account"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
var (
- UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
- NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
- UnsupportedOlmMessageType = errors.New("unsupported olm message type")
- DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
- DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
- SenderMismatch = errors.New("mismatched sender in olm payload")
- RecipientMismatch = errors.New("mismatched recipient in olm payload")
- RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
- ErrDuplicateMessage = errors.New("duplicate olm message")
+ ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
+ ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
+ ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type")
+ ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
+ ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
+ ErrSenderMismatch = errors.New("mismatched sender in olm payload")
+ ErrRecipientMismatch = errors.New("mismatched recipient in olm payload")
+ ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
+ ErrDuplicateMessage = errors.New("duplicate olm message")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ UnsupportedAlgorithm = ErrUnsupportedAlgorithm
+ NotEncryptedForMe = ErrNotEncryptedForMe
+ UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType
+ DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession
+ DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage
+ SenderMismatch = ErrSenderMismatch
+ RecipientMismatch = ErrRecipientMismatch
+ RecipientKeyMismatch = ErrRecipientKeyMismatch
)
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
@@ -53,13 +68,13 @@ type DecryptedOlmEvent struct {
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
- return nil, IncorrectEncryptedContentType
+ return nil, ErrIncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmOlmV1 {
- return nil, UnsupportedAlgorithm
+ return nil, ErrUnsupportedAlgorithm
}
ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()]
if !ok {
- return nil, NotEncryptedForMe
+ return nil, ErrNotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
@@ -75,7 +90,7 @@ type OlmEventKeys struct {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
- return nil, UnsupportedOlmMessageType
+ return nil, ErrUnsupportedOlmMessageType
}
log := mach.machOrContextLog(ctx).With().
@@ -99,11 +114,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
olmEvt.Type.Class = evt.Type.Class
if evt.Sender != olmEvt.Sender {
- return nil, SenderMismatch
+ return nil, ErrSenderMismatch
} else if mach.Client.UserID != olmEvt.Recipient {
- return nil, RecipientMismatch
+ return nil, ErrRecipientMismatch
} else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 {
- return nil, RecipientKeyMismatch
+ return nil, ErrRecipientKeyMismatch
}
if len(olmEvt.Content.VeryRaw) > 0 {
@@ -119,6 +134,9 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
func olmMessageHash(ciphertext string) ([32]byte, error) {
+ if ciphertext == "" {
+ return [32]byte{}, fmt.Errorf("empty ciphertext")
+ }
ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext)
return sha256.Sum256(ciphertextBytes), err
}
@@ -148,7 +166,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
if err != nil {
- if err == DecryptionFailedWithMatchingSession {
+ if err == ErrDecryptionFailedWithMatchingSession {
log.Warn().Msg("Found matching session, but decryption failed")
go mach.unwedgeDevice(log, sender, senderKey)
}
@@ -166,9 +184,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(log, sender, senderKey)
- return nil, DecryptionFailedForNormalMessage
+ return nil, ErrDecryptionFailedForNormalMessage
}
+ accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp"))
log.Trace().Msg("Trying to create inbound session")
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
@@ -180,6 +199,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
log.Debug().
Hex("ciphertext_hash", ciphertextHash[:]).
+ Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
Str("olm_session_description", session.Describe()).
Msg("Created inbound olm session")
ctx = log.WithContext(ctx)
@@ -188,6 +208,19 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err = session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
+ log.Debug().
+ Hex("ciphertext_hash", ciphertextHash[:]).
+ Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
+ Str("ciphertext", ciphertext).
+ Str("olm_session_description", session.Describe()).
+ Msg("DEBUG: Failed to decrypt prekey olm message with newly created session")
+ err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup)
+ if err2 != nil {
+ log.Debug().Err(err2).Msg("Goolm confirmed decryption failure")
+ } else {
+ log.Warn().Msg("Goolm decryption was successful after libolm failure?")
+ }
+
go mach.unwedgeDevice(log, sender, senderKey)
return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err)
}
@@ -205,6 +238,23 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
return plaintext, nil
}
+func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error {
+ acc, err := account.AccountFromPickled(accountBackup, []byte("tmp"))
+ if err != nil {
+ return fmt.Errorf("failed to unpickle olm account: %w", err)
+ }
+ sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext)
+ if err != nil {
+ return fmt.Errorf("failed to create inbound session: %w", err)
+ }
+ _, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey)
+ if err != nil {
+ // This is the expected result if libolm failed
+ return fmt.Errorf("failed to decrypt with new session: %w", err)
+ }
+ return nil
+}
+
const MaxOlmSessionsPerDevice = 5
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
@@ -263,10 +313,11 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
if err != nil {
log.Warn().Err(err).
Hex("ciphertext_hash", ciphertextHash[:]).
+ Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
Str("session_description", session.Describe()).
Msg("Failed to decrypt olm message")
if olmType == id.OlmMsgTypePreKey {
- return nil, DecryptionFailedWithMatchingSession
+ return nil, ErrDecryptionFailedWithMatchingSession
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
@@ -306,10 +357,10 @@ const MinUnwedgeInterval = 1 * time.Hour
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
log = log.With().Str("action", "unwedge olm session").Logger()
- ctx := log.WithContext(mach.BackgroundCtx)
+ ctx := log.WithContext(mach.backgroundCtx)
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
- delta := time.Now().Sub(prevUnwedge)
+ delta := time.Since(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
log.Debug().
Str("previous_recreation", delta.String()).
@@ -340,7 +391,10 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send
return
}
- log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session")
+ log.Debug().
+ Time("last_created", lastCreatedAt).
+ Stringer("device_id", deviceIdentity.DeviceID).
+ Msg("Creating new Olm session")
mach.devicesToUnwedgeLock.Lock()
mach.devicesToUnwedge[senderKey] = true
mach.devicesToUnwedgeLock.Unlock()
diff --git a/crypto/devicelist.go b/crypto/devicelist.go
index a2116ed5..f0d2b129 100644
--- a/crypto/devicelist.go
+++ b/crypto/devicelist.go
@@ -22,14 +22,23 @@ import (
)
var (
- MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
- MismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
- MismatchingSigningKey = errors.New("received update for device with different signing key")
- NoSigningKeyFound = errors.New("didn't find ed25519 signing key")
- NoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
- InvalidKeySignature = errors.New("invalid signature on device keys")
+ ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
+ ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
+ ErrMismatchingSigningKey = errors.New("received update for device with different signing key")
+ ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key")
+ ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
+ ErrInvalidKeySignature = errors.New("invalid signature on device keys")
+ ErrUserNotTracked = errors.New("user is not tracked")
+)
- ErrUserNotTracked = errors.New("user is not tracked")
+// Deprecated: use variables prefixed with Err
+var (
+ MismatchingDeviceID = ErrMismatchingDeviceID
+ MismatchingUserID = ErrMismatchingUserID
+ MismatchingSigningKey = ErrMismatchingSigningKey
+ NoSigningKeyFound = ErrNoSigningKeyFound
+ NoIdentityKeyFound = ErrNoIdentityKeyFound
+ InvalidKeySignature = ErrInvalidKeySignature
)
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
@@ -206,7 +215,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received")
data = make(map[id.UserID]map[id.DeviceID]*id.Device)
for userID, devices := range resp.DeviceKeys {
- log := log.With().Str("user_id", userID.String()).Logger()
+ log := log.With().Stringer("user_id", userID).Logger()
delete(req.DeviceKeys, userID)
newDevices := make(map[id.DeviceID]*id.Device)
@@ -222,7 +231,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
Msg("Updating devices in store")
changed := false
for deviceID, deviceKeys := range devices {
- log := log.With().Str("device_id", deviceID.String()).Logger()
+ log := log.With().Stringer("device_id", deviceID).Logger()
existing, ok := existingDevices[deviceID]
if !ok {
// New device
@@ -270,7 +279,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
}
}
for userID := range req.DeviceKeys {
- log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user")
+ log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user")
}
mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys)
@@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID)
func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) {
if deviceID != deviceKeys.DeviceID {
- return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID)
+ return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID)
} else if userID != deviceKeys.UserID {
- return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID)
+ return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID)
}
signingKey := deviceKeys.Keys.GetEd25519(deviceID)
identityKey := deviceKeys.Keys.GetCurve25519(deviceID)
if signingKey == "" {
- return nil, NoSigningKeyFound
+ return nil, ErrNoSigningKeyFound
} else if identityKey == "" {
- return nil, NoIdentityKeyFound
+ return nil, ErrNoIdentityKeyFound
}
if existing != nil && existing.SigningKey != signingKey {
- return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
+ return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey)
}
ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
if err != nil {
return existing, fmt.Errorf("failed to verify signature: %w", err)
} else if !ok {
- return existing, InvalidKeySignature
+ return existing, ErrInvalidKeySignature
}
name, ok := deviceKeys.Unsigned["device_display_name"].(string)
diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go
index 804e15de..88f9c8d4 100644
--- a/crypto/encryptmegolm.go
+++ b/crypto/encryptmegolm.go
@@ -25,8 +25,12 @@ import (
)
var (
- AlreadyShared = errors.New("group session already shared")
- NoGroupSession = errors.New("no group session created")
+ ErrNoGroupSession = errors.New("no group session created")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ NoGroupSession = ErrNoGroupSession
)
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
@@ -42,7 +46,7 @@ func getRawJSON[T any](content json.RawMessage, path ...string) *T {
return &result
}
-func getRelatesTo(content any) *event.RelatesTo {
+func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo {
contentJSON, ok := content.(json.RawMessage)
if ok {
return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to")
@@ -55,7 +59,7 @@ func getRelatesTo(content any) *event.RelatesTo {
if ok {
return relatable.OptionalGetRelatesTo()
}
- return nil
+ return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to")
}
func getMentions(content any) *event.Mentions {
@@ -83,15 +87,20 @@ type rawMegolmEvent struct {
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
func IsShareError(err error) bool {
- return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
+ return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession
}
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
+ if len(ciphertext) == 0 {
+ return 0, fmt.Errorf("empty ciphertext")
+ }
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
var err error
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
if err != nil {
return 0, err
+ } else if len(decoded) < 2+binary.MaxVarintLen64 {
+ return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded))
} else if decoded[0] != 3 || decoded[1] != 8 {
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
}
@@ -121,7 +130,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
} else if session == nil {
- return nil, NoGroupSession
+ return nil, ErrNoGroupSession
}
plaintext, err := json.Marshal(&rawMegolmEvent{
RoomID: roomID,
@@ -159,12 +168,21 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
Algorithm: id.AlgorithmMegolmV1,
SessionID: session.ID(),
MegolmCiphertext: ciphertext,
- RelatesTo: getRelatesTo(content),
+ RelatesTo: getRelatesTo(content, plaintext),
// These are deprecated
SenderKey: mach.account.IdentityKey(),
DeviceID: mach.Client.DeviceID,
}
+ if mach.MSC4392Relations && encrypted.RelatesTo != nil {
+ // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content.
+ // Other relations like threads are still left unencrypted.
+ encrypted.RelatesTo.InReplyTo = nil
+ encrypted.RelatesTo.IsFallingBack = false
+ if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" {
+ encrypted.RelatesTo = nil
+ }
+ }
if mach.PlaintextMentions {
encrypted.Mentions = getMentions(content)
}
@@ -209,7 +227,8 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
- return AlreadyShared
+ mach.machOrContextLog(ctx).Debug().Stringer("room_id", roomID).Msg("Not re-sharing group session, already shared")
+ return nil
}
log := mach.machOrContextLog(ctx).With().
Str("room_id", roomID.String()).
@@ -233,7 +252,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
var fetchKeysForUsers []id.UserID
for _, userID := range users {
- log := log.With().Str("target_user_id", userID.String()).Logger()
+ log := log.With().Stringer("target_user_id", userID).Logger()
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
if err != nil {
log.Err(err).Msg("Failed to get devices of user")
@@ -305,7 +324,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
toDeviceWithheld.Messages[userID] = withheld
}
- log := log.With().Str("target_user_id", userID.String()).Logger()
+ log := log.With().Stringer("target_user_id", userID).Logger()
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
log.Debug().
@@ -351,26 +370,19 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
log.Trace().Msg("Encrypting group session for all found devices")
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
+ logUsers := zerolog.Dict()
for userID, sessions := range olmSessions {
if len(sessions) == 0 {
continue
}
+ logDevices := zerolog.Dict()
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
- log.Trace().
- Stringer("target_user_id", userID).
- Stringer("target_device_id", deviceID).
- Stringer("target_identity_key", device.identity.IdentityKey).
- Msg("Encrypting group session for device")
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
+ logDevices.Str(string(deviceID), string(device.identity.IdentityKey))
deviceCount++
- log.Debug().
- Stringer("target_user_id", userID).
- Stringer("target_device_id", deviceID).
- Stringer("target_identity_key", device.identity.IdentityKey).
- Msg("Encrypted group session for device")
if !mach.DisableSharedGroupSessionTracking {
err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id)
if err != nil {
@@ -384,11 +396,13 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
}
}
}
+ logUsers.Dict(string(userID), logDevices)
}
log.Debug().
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
+ Dict("destination_map", logUsers).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
return err
diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go
index 80b76dc5..765307af 100644
--- a/crypto/encryptolm.go
+++ b/crypto/encryptolm.go
@@ -96,15 +96,19 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
panic(err)
}
log := mach.machOrContextLog(ctx)
- log.Debug().
- Str("recipient_identity_key", recipient.IdentityKey.String()).
- Str("olm_session_id", session.ID().String()).
- Str("olm_session_description", session.Describe()).
- Msg("Encrypting olm message")
msgType, ciphertext, err := session.Encrypt(plaintext)
if err != nil {
panic(err)
}
+ ciphertextStr := string(ciphertext)
+ ciphertextHash, _ := olmMessageHash(ciphertextStr)
+ log.Debug().
+ Stringer("event_type", evtType).
+ Str("recipient_identity_key", recipient.IdentityKey.String()).
+ Str("olm_session_id", session.ID().String()).
+ Str("olm_session_description", session.Describe()).
+ Hex("ciphertext_hash", ciphertextHash[:]).
+ Msg("Encrypted olm message")
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
if err != nil {
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
@@ -115,7 +119,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
OlmCiphertext: event.OlmCiphertexts{
recipient.IdentityKey: {
Type: msgType,
- Body: string(ciphertext),
+ Body: ciphertextStr,
},
},
}
diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go
index 4da08a73..b48843a4 100644
--- a/crypto/goolm/account/account.go
+++ b/crypto/goolm/account/account.go
@@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
if err != nil {
return err
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
- return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
return err
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair
diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go
index e1c9b452..d0dec5f0 100644
--- a/crypto/goolm/account/account_test.go
+++ b/crypto/goolm/account/account_test.go
@@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) {
account, err := account.NewAccount()
assert.NoError(t, err)
err = account.Unpickle(pickled, pickleKey)
- assert.ErrorIs(t, err, olm.ErrBadVersion)
+ assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion)
}
func TestLoopback(t *testing.T) {
diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go
index c6b9e523..ec392d7e 100644
--- a/crypto/goolm/account/register.go
+++ b/crypto/goolm/account/register.go
@@ -10,7 +10,7 @@ import (
"maunium.net/go/mautrix/crypto/olm"
)
-func init() {
+func Register() {
olm.InitNewAccount = func() (olm.Account, error) {
return NewAccount()
}
diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go
index e9759501..6e42d886 100644
--- a/crypto/goolm/crypto/curve25519.go
+++ b/crypto/goolm/crypto/curve25519.go
@@ -53,6 +53,7 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
// SharedSecret returns the shared secret between the key pair and the given public key.
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
+ // Note: the standard library checks that the output is non-zero
return c.PrivateKey.SharedSecret(pubKey)
}
diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go
index 9039c126..2550f15e 100644
--- a/crypto/goolm/crypto/curve25519_test.go
+++ b/crypto/goolm/crypto/curve25519_test.go
@@ -25,6 +25,8 @@ func TestCurve25519(t *testing.T) {
fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey)
assert.NoError(t, err)
assert.Equal(t, fromPrivate, firstKeypair)
+ _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength))
+ assert.Error(t, err)
}
func TestCurve25519Case1(t *testing.T) {
diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go
index 061a052a..58ee26f7 100644
--- a/crypto/goolm/goolmbase64/base64.go
+++ b/crypto/goolm/goolmbase64/base64.go
@@ -4,7 +4,8 @@ import (
"encoding/base64"
)
-// Deprecated: base64.RawStdEncoding should be used directly
+// These methods should only be used for raw byte operations, never with string conversion
+
func Decode(input []byte) ([]byte, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
@@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) {
return decoded[:writtenBytes], nil
}
-// Deprecated: base64.RawStdEncoding should be used directly
func Encode(input []byte) []byte {
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
base64.RawStdEncoding.Encode(encoded, input)
diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/libolmpickle/picklejson.go
index 308e472c..f765391f 100644
--- a/crypto/goolm/libolmpickle/picklejson.go
+++ b/crypto/goolm/libolmpickle/picklejson.go
@@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
}
}
if decrypted[0] != pickleVersion {
- return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion)
+ return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion)
}
err = json.Unmarshal(decrypted[1:], object)
if err != nil {
diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go
index a71cf302..b06756a9 100644
--- a/crypto/goolm/message/decoder.go
+++ b/crypto/goolm/message/decoder.go
@@ -3,6 +3,9 @@ package message
import (
"bytes"
"encoding/binary"
+ "fmt"
+
+ "maunium.net/go/mautrix/crypto/olm"
)
type Decoder struct {
@@ -20,6 +23,8 @@ func (d *Decoder) ReadVarInt() (uint64, error) {
func (d *Decoder) ReadVarBytes() ([]byte, error) {
if n, err := d.ReadVarInt(); err != nil {
return nil, err
+ } else if n > uint64(d.Len()) {
+ return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available())
} else {
out := make([]byte, n)
_, err = d.Read(out)
diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go
index c2a43b1f..c83540c1 100644
--- a/crypto/goolm/message/group_message.go
+++ b/crypto/goolm/message/group_message.go
@@ -2,10 +2,12 @@ package message
import (
"bytes"
+ "fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
+ "maunium.net/go/mautrix/crypto/olm"
)
const (
@@ -36,6 +38,9 @@ func (r *GroupMessage) Decode(input []byte) (err error) {
if err != nil {
return
}
+ if r.Version != protocolVersion {
+ return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
+ }
for {
// Read Key
diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go
index 8bb6e0cd..b161a2d1 100644
--- a/crypto/goolm/message/message.go
+++ b/crypto/goolm/message/message.go
@@ -2,10 +2,12 @@ package message
import (
"bytes"
+ "fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
+ "maunium.net/go/mautrix/crypto/olm"
)
const (
@@ -40,6 +42,9 @@ func (r *Message) Decode(input []byte) (err error) {
if err != nil {
return
}
+ if r.Version != protocolVersion {
+ return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
+ }
for {
// Read Key
diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go
index 22ebf9c3..4e3d495d 100644
--- a/crypto/goolm/message/prekey_message.go
+++ b/crypto/goolm/message/prekey_message.go
@@ -1,6 +1,7 @@
package message
import (
+ "fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm/crypto"
@@ -22,6 +23,11 @@ type PreKeyMessage struct {
Message []byte `json:"message"`
}
+// TODO deduplicate constant with one in session/olm_session.go
+const (
+ protocolVersion = 0x3
+)
+
// Decodes decodes the input and populates the corresponding fileds.
func (r *PreKeyMessage) Decode(input []byte) (err error) {
r.Version = 0
@@ -41,6 +47,9 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) {
}
return
}
+ if r.Version != protocolVersion {
+ return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
+ }
for {
// Read Key
diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go
index 956868b2..d58dbb21 100644
--- a/crypto/goolm/message/session_export.go
+++ b/crypto/goolm/message/session_export.go
@@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error {
return fmt.Errorf("decrypt: %w", olm.ErrBadInput)
}
if input[0] != sessionExportVersion {
- return fmt.Errorf("decrypt: %w", olm.ErrBadVersion)
+ return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go
index 16240945..d04ef15a 100644
--- a/crypto/goolm/message/session_sharing.go
+++ b/crypto/goolm/message/session_sharing.go
@@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
}
s.PublicKey = publicKey
if input[0] != sessionSharingVersion {
- return fmt.Errorf("verify: %w", olm.ErrBadVersion)
+ return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go
index afb01f74..cdb20eb1 100644
--- a/crypto/goolm/pk/decryption.go
+++ b/crypto/goolm/pk/decryption.go
@@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error {
if pickledVersion == decryptionPickleVersionLibOlm {
return a.KeyPair.UnpickleLibOlm(decoder)
} else {
- return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm)
+ return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm)
}
}
diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go
index 23f67ddf..2897d9b0 100644
--- a/crypto/goolm/pk/encryption.go
+++ b/crypto/goolm/pk/encryption.go
@@ -37,6 +37,9 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat
return nil, nil, err
}
cipher, err := aessha2.NewAESSHA2(sharedSecret, nil)
+ if err != nil {
+ return nil, nil, err
+ }
ciphertext, err = cipher.Encrypt(plaintext)
if err != nil {
return nil, nil, err
diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go
index b7af6a5b..0e27b568 100644
--- a/crypto/goolm/pk/register.go
+++ b/crypto/goolm/pk/register.go
@@ -8,7 +8,7 @@ package pk
import "maunium.net/go/mautrix/crypto/olm"
-func init() {
+func Register() {
olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) {
return NewSigningFromSeed(seed)
}
diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go
index 229c9bd2..9901ada8 100644
--- a/crypto/goolm/ratchet/olm.go
+++ b/crypto/goolm/ratchet/olm.go
@@ -142,7 +142,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) {
return nil, err
}
if message.Version != protocolVersion {
- return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion)
+ return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion)
}
if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 {
return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat)
diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go
index 80ed206b..800f567f 100644
--- a/crypto/goolm/register.go
+++ b/crypto/goolm/register.go
@@ -7,19 +7,23 @@
package goolm
import (
- // Need to import these subpackages to ensure they are registered
- _ "maunium.net/go/mautrix/crypto/goolm/account"
- _ "maunium.net/go/mautrix/crypto/goolm/pk"
- _ "maunium.net/go/mautrix/crypto/goolm/session"
-
+ "maunium.net/go/mautrix/crypto/goolm/account"
+ "maunium.net/go/mautrix/crypto/goolm/pk"
+ "maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/olm"
)
-func init() {
+func Register() {
+ olm.Driver = "goolm"
+
olm.GetVersion = func() (major, minor, patch uint8) {
return 3, 2, 15
}
olm.SetPickleKeyImpl = func(key []byte) {
panic("gob and json encoding is deprecated and not supported with goolm")
}
+
+ account.Register()
+ pk.Register()
+ session.Register()
}
diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go
index 80dd71cc..7ccbd26d 100644
--- a/crypto/goolm/session/megolm_inbound_session.go
+++ b/crypto/goolm/session/megolm_inbound_session.go
@@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet,
}
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
// the counter is before our initial ratchet - we can't decode this
- return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable)
+ return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex)
}
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
copiedRatchet := o.InitialRatchet
@@ -126,7 +126,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error)
return nil, 0, err
}
if msg.Version != protocolVersion {
- return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion)
+ return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion)
}
if msg.Ciphertext == nil || !msg.HasMessageIndex {
return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat)
@@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error {
return err
}
if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 {
- return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil {
diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go
index 2b8e1c84..7f923534 100644
--- a/crypto/goolm/session/megolm_outbound_session.go
+++ b/crypto/goolm/session/megolm_outbound_session.go
@@ -101,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
- if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
- return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ if err != nil {
+ return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err)
+ } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
+ return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil {
return err
diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go
index b99ab630..a1cb8d66 100644
--- a/crypto/goolm/session/olm_session.go
+++ b/crypto/goolm/session/olm_session.go
@@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received
msg := message.Message{}
err = msg.Decode(oneTimeMsg.Message)
if err != nil {
- return nil, fmt.Errorf("Message decode: %w", err)
+ return nil, fmt.Errorf("message decode: %w", err)
}
if len(msg.RatchetKey) == 0 {
- return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat)
+ return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat)
}
//Init Ratchet
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
@@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID {
copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey)
copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey)
hash := sha256.Sum256(message)
- res := id.SessionID(goolmbase64.Encode(hash[:]))
+ res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:]))
return res
}
@@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e
if len(crypttext) == 0 {
return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput)
}
- decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext))
+ decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext)
if err != nil {
return nil, err
}
@@ -365,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error {
func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
+ if err != nil {
+ return fmt.Errorf("unpickle olmSession: failed to read version: %w", err)
+ }
var includesChainIndex bool
switch pickledVersion {
@@ -373,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
case uint32(0x80000001):
includesChainIndex = true
default:
- return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if o.ReceivedMessage, err = decoder.ReadBool(); err != nil {
diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go
index 09ed42d4..b95a44ac 100644
--- a/crypto/goolm/session/register.go
+++ b/crypto/goolm/session/register.go
@@ -10,11 +10,11 @@ import (
"maunium.net/go/mautrix/crypto/olm"
)
-func init() {
+func Register() {
// Inbound Session
olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
if len(key) == 0 {
key = []byte(" ")
@@ -23,13 +23,13 @@ func init() {
}
olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
return NewMegolmInboundSession(sessionKey)
}
olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
return NewMegolmInboundSessionFromExport(sessionKey)
}
@@ -40,7 +40,7 @@ func init() {
// Outbound Session
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
lenKey := len(key)
if lenKey == 0 {
diff --git a/crypto/keybackup.go b/crypto/keybackup.go
index d8b3d715..7b3c30db 100644
--- a/crypto/keybackup.go
+++ b/crypto/keybackup.go
@@ -56,11 +56,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context,
// ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private
// key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing
// from a verified device belonging to the same user."
- megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
- if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
- log.Debug().Msg("key backup is trusted based on derived public key")
- return versionInfo, nil
- } else {
+ if megolmBackupKey != nil {
+ megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
+ if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
+ log.Debug().Msg("Key backup is trusted based on derived public key")
+ return versionInfo, nil
+ }
log.Debug().
Stringer("expected_key", megolmBackupDerivedPublicKey).
Stringer("actual_key", versionInfo.AuthData.PublicKey).
@@ -199,13 +200,14 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving(
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
RoomID: roomID,
- ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()),
+ ForwardingChains: keyBackupData.ForwardingKeyChain,
id: sessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
KeyBackupVersion: version,
+ KeySource: id.KeySourceBackup,
}, nil
}
diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go
index 47616a20..fd6f105d 100644
--- a/crypto/keyexport_test.go
+++ b/crypto/keyexport_test.go
@@ -31,5 +31,5 @@ func TestExportKeys(t *testing.T) {
))
data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess})
assert.NoError(t, err)
- assert.Len(t, data, 836)
+ assert.Len(t, data, 893)
}
diff --git a/crypto/keyimport.go b/crypto/keyimport.go
index 36ad6b9c..3ffc74a5 100644
--- a/crypto/keyimport.go
+++ b/crypto/keyimport.go
@@ -108,19 +108,20 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
return false, ErrMismatchingExportedSessionID
}
igs := &InboundGroupSession{
- Internal: igsInternal,
- SigningKey: session.SenderClaimedKeys.Ed25519,
- SenderKey: session.SenderKey,
- RoomID: session.RoomID,
- // TODO should we add something here to mark the signing key as unverified like key requests do?
+ Internal: igsInternal,
+ SigningKey: session.SenderClaimedKeys.Ed25519,
+ SenderKey: session.SenderKey,
+ RoomID: session.RoomID,
ForwardingChains: session.ForwardingChains,
-
- ReceivedAt: time.Now().UTC(),
+ KeySource: id.KeySourceImport,
+ ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
firstKnownIndex := igs.Internal.FirstKnownIndex()
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex {
- // We already have an equivalent or better session in the store, so don't override it.
+ // We already have an equivalent or better session in the store, so don't override it,
+ // but do notify the session received callback just in case.
+ mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex())
return false, nil
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
diff --git a/crypto/keysharing.go b/crypto/keysharing.go
index f1d427af..19a68c87 100644
--- a/crypto/keysharing.go
+++ b/crypto/keysharing.go
@@ -189,6 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
+ KeySource: id.KeySourceForward,
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
@@ -214,6 +215,7 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare
RoomID: request.RoomID,
Algorithm: request.Algorithm,
SessionID: request.SessionID,
+ //lint:ignore SA1019 This is just echoing back the deprecated field
SenderKey: request.SenderKey,
Code: rejection.Code,
Reason: rejection.Reason,
@@ -263,9 +265,14 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev
log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing")
return &KeyShareRejectNoResponse
} else if !isShared {
- // TODO differentiate session not shared with requester vs session not created by this device?
- log.Debug().Msg("Rejecting key request for unshared session")
- return &KeyShareRejectNotRecipient
+ igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID)
+ if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey {
+ log.Debug().Msg("Rejecting key request for unshared session")
+ return &KeyShareRejectNotRecipient
+ }
+ // Note: this case will also happen for redacted sessions and database errors
+ log.Debug().Msg("Rejecting key request for session created by another device")
+ return &KeyShareRejectNoResponse
}
log.Debug().Msg("Accepting key request for shared session")
return nil
@@ -323,7 +330,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
- mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ if sender != mach.Client.UserID {
+ mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ }
} else {
log.Error().Err(err).Msg("Failed to get group session to forward")
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
@@ -331,7 +340,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
return
} else if igs == nil {
log.Error().Msg("Didn't find group session to forward")
- mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ if sender != mach.Client.UserID {
+ mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ }
return
}
if internalID := igs.ID(); internalID != content.Body.SessionID {
@@ -356,7 +367,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
SessionID: igs.ID(),
SessionKey: string(exportedKey),
},
- SenderKey: content.Body.SenderKey,
+ SenderKey: igs.SenderKey,
ForwardingKeyChain: igs.ForwardingChains,
SenderClaimedKey: igs.SigningKey,
},
diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go
index cddce7ce..0350f083 100644
--- a/crypto/libolm/account.go
+++ b/crypto/libolm/account.go
@@ -8,6 +8,7 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/json"
+ "runtime"
"unsafe"
"github.com/tidwall/gjson"
@@ -22,18 +23,6 @@ type Account struct {
mem []byte
}
-func init() {
- olm.InitNewAccount = func() (olm.Account, error) {
- return NewAccount()
- }
- olm.InitBlankAccount = func() olm.Account {
- return NewBlankAccount()
- }
- olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) {
- return AccountFromPickled(pickled, key)
- }
-}
-
// Ensure that [Account] implements [olm.Account].
var _ olm.Account = (*Account)(nil)
@@ -44,7 +33,7 @@ var _ olm.Account = (*Account)(nil)
// "INVALID_BASE64".
func AccountFromPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
a := NewBlankAccount()
return a, a.Unpickle(pickled, key)
@@ -53,7 +42,7 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) {
func NewBlankAccount() *Account {
memory := make([]byte, accountSize())
return &Account{
- int: C.olm_account(unsafe.Pointer(&memory[0])),
+ int: C.olm_account(unsafe.Pointer(unsafe.SliceData(memory))),
mem: memory,
}
}
@@ -64,12 +53,13 @@ func NewAccount() (*Account, error) {
random := make([]byte, a.createRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
- panic(olm.NotEnoughGoRandom)
+ panic(olm.ErrNotEnoughGoRandom)
}
ret := C.olm_create_account(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&random[0]),
+ unsafe.Pointer(unsafe.SliceData(random)),
C.size_t(len(random)))
+ runtime.KeepAlive(random)
if ret == errorVal() {
return nil, a.lastError()
} else {
@@ -138,14 +128,14 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
// supplied key.
func (a *Account) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, a.pickleLen())
r := C.olm_pickle_account(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
+ unsafe.Pointer(unsafe.SliceData(pickled)),
C.size_t(len(pickled)))
if r == errorVal() {
return nil, a.lastError()
@@ -155,13 +145,13 @@ func (a *Account) Pickle(key []byte) ([]byte, error) {
func (a *Account) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_account(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
+ unsafe.Pointer(unsafe.SliceData(pickled)),
C.size_t(len(pickled)))
if r == errorVal() {
return a.lastError()
@@ -208,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) {
// Deprecated
func (a *Account) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if a.int == nil {
*a = *NewBlankAccount()
@@ -221,7 +211,7 @@ func (a *Account) IdentityKeysJSON() ([]byte, error) {
identityKeys := make([]byte, a.identityKeysLen())
r := C.olm_account_identity_keys(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&identityKeys[0]),
+ unsafe.Pointer(unsafe.SliceData(identityKeys)),
C.size_t(len(identityKeys)))
if r == errorVal() {
return nil, a.lastError()
@@ -245,15 +235,16 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
// Account.
func (a *Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
- panic(olm.EmptyInput)
+ panic(olm.ErrEmptyInput)
}
signature := make([]byte, a.signatureLen())
r := C.olm_account_sign(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&message[0]),
+ unsafe.Pointer(unsafe.SliceData(message)),
C.size_t(len(message)),
- unsafe.Pointer(&signature[0]),
+ unsafe.Pointer(unsafe.SliceData(signature)),
C.size_t(len(signature)))
+ runtime.KeepAlive(message)
if r == errorVal() {
panic(a.lastError())
}
@@ -277,8 +268,9 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen())
r := C.olm_account_one_time_keys(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&oneTimeKeysJSON[0]),
- C.size_t(len(oneTimeKeysJSON)))
+ unsafe.Pointer(unsafe.SliceData(oneTimeKeysJSON)),
+ C.size_t(len(oneTimeKeysJSON)),
+ )
if r == errorVal() {
return nil, a.lastError()
}
@@ -307,13 +299,15 @@ func (a *Account) GenOneTimeKeys(num uint) error {
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
_, err := rand.Read(random)
if err != nil {
- return olm.NotEnoughGoRandom
+ return olm.ErrNotEnoughGoRandom
}
r := C.olm_account_generate_one_time_keys(
(*C.OlmAccount)(a.int),
C.size_t(num),
- unsafe.Pointer(&random[0]),
- C.size_t(len(random)))
+ unsafe.Pointer(unsafe.SliceData(random)),
+ C.size_t(len(random)),
+ )
+ runtime.KeepAlive(random)
if r == errorVal() {
return a.lastError()
}
@@ -325,23 +319,29 @@ func (a *Account) GenOneTimeKeys(num uint) error {
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
random := make([]byte, s.createOutboundRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
- panic(olm.NotEnoughGoRandom)
+ panic(olm.ErrNotEnoughGoRandom)
}
+ theirIdentityKeyCopy := []byte(theirIdentityKey)
+ theirOneTimeKeyCopy := []byte(theirOneTimeKey)
r := C.olm_create_outbound_session(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&([]byte(theirIdentityKey)[0])),
- C.size_t(len(theirIdentityKey)),
- unsafe.Pointer(&([]byte(theirOneTimeKey)[0])),
- C.size_t(len(theirOneTimeKey)),
- unsafe.Pointer(&random[0]),
- C.size_t(len(random)))
+ unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)),
+ C.size_t(len(theirIdentityKeyCopy)),
+ unsafe.Pointer(unsafe.SliceData(theirOneTimeKeyCopy)),
+ C.size_t(len(theirOneTimeKeyCopy)),
+ unsafe.Pointer(unsafe.SliceData(random)),
+ C.size_t(len(random)),
+ )
+ runtime.KeepAlive(random)
+ runtime.KeepAlive(theirIdentityKeyCopy)
+ runtime.KeepAlive(theirOneTimeKeyCopy)
if r == errorVal() {
return nil, s.lastError()
}
@@ -357,14 +357,17 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
if len(oneTimeKeyMsg) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
+ oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_create_inbound_session(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
- C.size_t(len(oneTimeKeyMsg)))
+ unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
+ C.size_t(len(oneTimeKeyMsgCopy)),
+ )
+ runtime.KeepAlive(oneTimeKeyMsgCopy)
if r == errorVal() {
return nil, s.lastError()
}
@@ -380,16 +383,21 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
+ theirIdentityKeyCopy := []byte(*theirIdentityKey)
+ oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
s := NewBlankSession()
r := C.olm_create_inbound_session_from(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
- unsafe.Pointer(&([]byte(*theirIdentityKey)[0])),
- C.size_t(len(*theirIdentityKey)),
- unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
- C.size_t(len(oneTimeKeyMsg)))
+ unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)),
+ C.size_t(len(theirIdentityKeyCopy)),
+ unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
+ C.size_t(len(oneTimeKeyMsgCopy)),
+ )
+ runtime.KeepAlive(theirIdentityKeyCopy)
+ runtime.KeepAlive(oneTimeKeyMsgCopy)
if r == errorVal() {
return nil, s.lastError()
}
@@ -402,7 +410,8 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTime
func (a *Account) RemoveOneTimeKeys(s olm.Session) error {
r := C.olm_remove_one_time_keys(
(*C.OlmAccount)(a.int),
- (*C.OlmSession)(s.(*Session).int))
+ (*C.OlmSession)(s.(*Session).int),
+ )
if r == errorVal() {
return a.lastError()
}
diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go
index 9ca415ee..6fb5512b 100644
--- a/crypto/libolm/error.go
+++ b/crypto/libolm/error.go
@@ -11,21 +11,21 @@ import (
)
var errorMap = map[string]error{
- "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom,
- "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall,
- "BAD_MESSAGE_VERSION": olm.BadMessageVersion,
- "BAD_MESSAGE_FORMAT": olm.BadMessageFormat,
- "BAD_MESSAGE_MAC": olm.BadMessageMAC,
- "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID,
- "INVALID_BASE64": olm.InvalidBase64,
- "BAD_ACCOUNT_KEY": olm.BadAccountKey,
- "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion,
- "CORRUPTED_PICKLE": olm.CorruptedPickle,
- "BAD_SESSION_KEY": olm.BadSessionKey,
- "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex,
- "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle,
- "BAD_SIGNATURE": olm.BadSignature,
- "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall,
+ "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom,
+ "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall,
+ "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion,
+ "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat,
+ "BAD_MESSAGE_MAC": olm.ErrBadMAC,
+ "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID,
+ "INVALID_BASE64": olm.ErrLibolmInvalidBase64,
+ "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey,
+ "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion,
+ "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle,
+ "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey,
+ "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex,
+ "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle,
+ "BAD_SIGNATURE": olm.ErrBadSignature,
+ "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall,
}
func convertError(errCode string) error {
diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go
index 1e25748d..8815ac32 100644
--- a/crypto/libolm/inboundgroupsession.go
+++ b/crypto/libolm/inboundgroupsession.go
@@ -7,6 +7,7 @@ import "C"
import (
"bytes"
"encoding/base64"
+ "runtime"
"unsafe"
"maunium.net/go/mautrix/crypto/olm"
@@ -20,21 +21,6 @@ type InboundGroupSession struct {
mem []byte
}
-func init() {
- olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
- return InboundGroupSessionFromPickled(pickled, key)
- }
- olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
- return NewInboundGroupSession(sessionKey)
- }
- olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
- return InboundGroupSessionImport(sessionKey)
- }
- olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession {
- return NewBlankInboundGroupSession()
- }
-}
-
// Ensure that [InboundGroupSession] implements [olm.InboundGroupSession].
var _ olm.InboundGroupSession = (*InboundGroupSession)(nil)
@@ -45,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil)
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
lenKey := len(key)
if lenKey == 0 {
@@ -62,13 +48,15 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession,
// "OLM_BAD_SESSION_KEY".
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_init_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(&sessionKey[0]),
- C.size_t(len(sessionKey)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))),
+ C.size_t(len(sessionKey)),
+ )
+ runtime.KeepAlive(sessionKey)
if r == errorVal() {
return nil, s.lastError()
}
@@ -81,13 +69,15 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
// error will be "OLM_BAD_SESSION_KEY".
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_import_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(&sessionKey[0]),
- C.size_t(len(sessionKey)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))),
+ C.size_t(len(sessionKey)),
+ )
+ runtime.KeepAlive(sessionKey)
if r == errorVal() {
return nil, s.lastError()
}
@@ -104,7 +94,7 @@ func inboundGroupSessionSize() uint {
func NewBlankInboundGroupSession() *InboundGroupSession {
memory := make([]byte, inboundGroupSessionSize())
return &InboundGroupSession{
- int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])),
+ int: C.olm_inbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))),
mem: memory,
}
}
@@ -134,15 +124,17 @@ func (s *InboundGroupSession) pickleLen() uint {
// InboundGroupSession using the supplied key.
func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
- C.size_t(len(pickled)))
+ unsafe.Pointer(unsafe.SliceData(pickled)),
+ C.size_t(len(pickled)),
+ )
+ runtime.KeepAlive(key)
if r == errorVal() {
return nil, s.lastError()
}
@@ -151,16 +143,18 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
} else if len(pickled) == 0 {
- return olm.EmptyInput
+ return olm.ErrEmptyInput
}
r := C.olm_unpickle_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
- C.size_t(len(pickled)))
+ unsafe.Pointer(unsafe.SliceData(pickled)),
+ C.size_t(len(pickled)),
+ )
+ runtime.KeepAlive(key)
if r == errorVal() {
return s.lastError()
}
@@ -206,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankInboundGroupSession()
@@ -223,14 +217,16 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
// will be "BAD_MESSAGE_FORMAT".
func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
if len(message) == 0 {
- return 0, olm.EmptyInput
+ return 0, olm.ErrEmptyInput
}
// olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it
- message = bytes.Clone(message)
+ messageCopy := bytes.Clone(message)
r := C.olm_group_decrypt_max_plaintext_length(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(&message[0]),
- C.size_t(len(message)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))),
+ C.size_t(len(messageCopy)),
+ )
+ runtime.KeepAlive(messageCopy)
if r == errorVal() {
return 0, s.lastError()
}
@@ -248,23 +244,24 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro
// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
if len(message) == 0 {
- return nil, 0, olm.EmptyInput
+ return nil, 0, olm.ErrEmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message)
if err != nil {
return nil, 0, err
}
- messageCopy := make([]byte, len(message))
- copy(messageCopy, message)
+ messageCopy := bytes.Clone(message)
plaintext := make([]byte, decryptMaxPlaintextLen)
var messageIndex uint32
r := C.olm_group_decrypt(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(&messageCopy[0]),
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))),
C.size_t(len(messageCopy)),
- (*C.uint8_t)(&plaintext[0]),
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))),
C.size_t(len(plaintext)),
- (*C.uint32_t)(&messageIndex))
+ (*C.uint32_t)(unsafe.Pointer(&messageIndex)),
+ )
+ runtime.KeepAlive(messageCopy)
if r == errorVal() {
return nil, 0, s.lastError()
}
@@ -281,8 +278,9 @@ func (s *InboundGroupSession) ID() id.SessionID {
sessionID := make([]byte, s.sessionIdLen())
r := C.olm_inbound_group_session_id(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(&sessionID[0]),
- C.size_t(len(sessionID)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))),
+ C.size_t(len(sessionID)),
+ )
if r == errorVal() {
panic(s.lastError())
}
@@ -318,9 +316,10 @@ func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
key := make([]byte, s.exportLen())
r := C.olm_export_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(&key[0]),
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(key))),
C.size_t(len(key)),
- C.uint32_t(messageIndex))
+ C.uint32_t(messageIndex),
+ )
if r == errorVal() {
return nil, s.lastError()
}
diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go
index a21f8d4a..ca5b68f7 100644
--- a/crypto/libolm/outboundgroupsession.go
+++ b/crypto/libolm/outboundgroupsession.go
@@ -7,6 +7,7 @@ import "C"
import (
"crypto/rand"
"encoding/base64"
+ "runtime"
"unsafe"
"maunium.net/go/mautrix/crypto/olm"
@@ -20,18 +21,6 @@ type OutboundGroupSession struct {
mem []byte
}
-func init() {
- olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
- if len(pickled) == 0 {
- return nil, olm.EmptyInput
- }
- s := NewBlankOutboundGroupSession()
- return s, s.Unpickle(pickled, key)
- }
- olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() }
- olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() }
-}
-
// Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession].
var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil)
@@ -44,8 +33,10 @@ func NewOutboundGroupSession() (*OutboundGroupSession, error) {
}
r := C.olm_init_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(&random[0]),
- C.size_t(len(random)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(random))),
+ C.size_t(len(random)),
+ )
+ runtime.KeepAlive(random)
if r == errorVal() {
return nil, s.lastError()
}
@@ -62,7 +53,7 @@ func outboundGroupSessionSize() uint {
func NewBlankOutboundGroupSession() *OutboundGroupSession {
memory := make([]byte, outboundGroupSessionSize())
return &OutboundGroupSession{
- int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])),
+ int: C.olm_outbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))),
mem: memory,
}
}
@@ -93,15 +84,17 @@ func (s *OutboundGroupSession) pickleLen() uint {
// OutboundGroupSession using the supplied key.
func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
- C.size_t(len(pickled)))
+ unsafe.Pointer(unsafe.SliceData(pickled)),
+ C.size_t(len(pickled)),
+ )
+ runtime.KeepAlive(key)
if r == errorVal() {
return nil, s.lastError()
}
@@ -110,14 +103,17 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
- C.size_t(len(pickled)))
+ unsafe.Pointer(unsafe.SliceData(pickled)),
+ C.size_t(len(pickled)),
+ )
+ runtime.KeepAlive(pickled)
+ runtime.KeepAlive(key)
if r == errorVal() {
return s.lastError()
}
@@ -163,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankOutboundGroupSession()
@@ -187,15 +183,17 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
// as base64.
func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if len(plaintext) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
message := make([]byte, s.encryptMsgLen(len(plaintext)))
r := C.olm_group_encrypt(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(&plaintext[0]),
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))),
C.size_t(len(plaintext)),
- (*C.uint8_t)(&message[0]),
- C.size_t(len(message)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))),
+ C.size_t(len(message)),
+ )
+ runtime.KeepAlive(plaintext)
if r == errorVal() {
return nil, s.lastError()
}
@@ -212,8 +210,9 @@ func (s *OutboundGroupSession) ID() id.SessionID {
sessionID := make([]byte, s.sessionIdLen())
r := C.olm_outbound_group_session_id(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(&sessionID[0]),
- C.size_t(len(sessionID)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))),
+ C.size_t(len(sessionID)),
+ )
if r == errorVal() {
panic(s.lastError())
}
@@ -236,8 +235,9 @@ func (s *OutboundGroupSession) Key() string {
sessionKey := make([]byte, s.sessionKeyLen())
r := C.olm_outbound_group_session_key(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(&sessionKey[0]),
- C.size_t(len(sessionKey)))
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))),
+ C.size_t(len(sessionKey)),
+ )
if r == errorVal() {
panic(s.lastError())
}
diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go
index db8d35c5..2683cf15 100644
--- a/crypto/libolm/pk.go
+++ b/crypto/libolm/pk.go
@@ -14,6 +14,7 @@ import "C"
import (
"crypto/rand"
"encoding/json"
+ "runtime"
"unsafe"
"github.com/tidwall/sjson"
@@ -34,16 +35,6 @@ type PKSigning struct {
// Ensure that [PKSigning] implements [olm.PKSigning].
var _ olm.PKSigning = (*PKSigning)(nil)
-func init() {
- olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() }
- olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) {
- return NewPKSigningFromSeed(seed)
- }
- olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) {
- return NewPkDecryption(privateKey)
- }
-}
-
func pkSigningSize() uint {
return uint(C.olm_pk_signing_size())
}
@@ -63,7 +54,7 @@ func pkSigningSignatureLength() uint {
func newBlankPKSigning() *PKSigning {
memory := make([]byte, pkSigningSize())
return &PKSigning{
- int: C.olm_pk_signing(unsafe.Pointer(&memory[0])),
+ int: C.olm_pk_signing(unsafe.Pointer(unsafe.SliceData(memory))),
mem: memory,
}
}
@@ -73,9 +64,14 @@ func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) {
p := newBlankPKSigning()
p.clear()
pubKey := make([]byte, pkSigningPublicKeyLength())
- if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int),
- unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
- unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() {
+ r := C.olm_pk_signing_key_from_seed(
+ (*C.OlmPkSigning)(p.int),
+ unsafe.Pointer(unsafe.SliceData(pubKey)),
+ C.size_t(len(pubKey)),
+ unsafe.Pointer(unsafe.SliceData(seed)),
+ C.size_t(len(seed)),
+ )
+ if r == errorVal() {
return nil, p.lastError()
}
p.publicKey = id.Ed25519(pubKey)
@@ -90,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) {
seed := make([]byte, pkSigningSeedLength())
_, err := rand.Read(seed)
if err != nil {
- panic(olm.NotEnoughGoRandom)
+ panic(olm.ErrNotEnoughGoRandom)
}
pk, err := NewPKSigningFromSeed(seed)
return pk, err
@@ -112,8 +108,15 @@ func (p *PKSigning) clear() {
// Sign creates a signature for the given message using this key.
func (p *PKSigning) Sign(message []byte) ([]byte, error) {
signature := make([]byte, pkSigningSignatureLength())
- if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)),
- (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() {
+ r := C.olm_pk_sign(
+ (*C.OlmPkSigning)(p.int),
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))),
+ C.size_t(len(message)),
+ (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(signature))),
+ C.size_t(len(signature)),
+ )
+ runtime.KeepAlive(message)
+ if r == errorVal() {
return nil, p.lastError()
}
return signature, nil
@@ -157,15 +160,21 @@ func pkDecryptionPublicKeySize() uint {
func NewPkDecryption(privateKey []byte) (*PKDecryption, error) {
memory := make([]byte, pkDecryptionSize())
p := &PKDecryption{
- int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])),
+ int: C.olm_pk_decryption(unsafe.Pointer(unsafe.SliceData(memory))),
mem: memory,
}
p.clear()
pubKey := make([]byte, pkDecryptionPublicKeySize())
- if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int),
- unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
- unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() {
+ r := C.olm_pk_key_from_private(
+ (*C.OlmPkDecryption)(p.int),
+ unsafe.Pointer(unsafe.SliceData(pubKey)),
+ C.size_t(len(pubKey)),
+ unsafe.Pointer(unsafe.SliceData(privateKey)),
+ C.size_t(len(privateKey)),
+ )
+ runtime.KeepAlive(privateKey)
+ if r == errorVal() {
return nil, p.lastError()
}
p.publicKey = pubKey
@@ -178,14 +187,26 @@ func (p *PKDecryption) PublicKey() id.Curve25519 {
}
func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
- maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext))))
+ maxPlaintextLength := uint(C.olm_pk_max_plaintext_length(
+ (*C.OlmPkDecryption)(p.int),
+ C.size_t(len(ciphertext)),
+ ))
plaintext := make([]byte, maxPlaintextLength)
- size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int),
- unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)),
- unsafe.Pointer(&mac[0]), C.size_t(len(mac)),
- unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)),
- unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext)))
+ size := C.olm_pk_decrypt(
+ (*C.OlmPkDecryption)(p.int),
+ unsafe.Pointer(unsafe.SliceData(ephemeralKey)),
+ C.size_t(len(ephemeralKey)),
+ unsafe.Pointer(unsafe.SliceData(mac)),
+ C.size_t(len(mac)),
+ unsafe.Pointer(unsafe.SliceData(ciphertext)),
+ C.size_t(len(ciphertext)),
+ unsafe.Pointer(unsafe.SliceData(plaintext)),
+ C.size_t(len(plaintext)),
+ )
+ runtime.KeepAlive(ephemeralKey)
+ runtime.KeepAlive(mac)
+ runtime.KeepAlive(ciphertext)
if size == errorVal() {
return nil, p.lastError()
}
diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go
index a423a7d0..ddf84613 100644
--- a/crypto/libolm/register.go
+++ b/crypto/libolm/register.go
@@ -3,19 +3,73 @@ package libolm
// #cgo LDFLAGS: -lolm -lstdc++
// #include
import "C"
-import "maunium.net/go/mautrix/crypto/olm"
+import (
+ "unsafe"
+
+ "maunium.net/go/mautrix/crypto/olm"
+)
var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm")
-func init() {
+func Register() {
+ olm.Driver = "libolm"
+
olm.GetVersion = func() (major, minor, patch uint8) {
C.olm_get_library_version(
- (*C.uint8_t)(&major),
- (*C.uint8_t)(&minor),
- (*C.uint8_t)(&patch))
+ (*C.uint8_t)(unsafe.Pointer(&major)),
+ (*C.uint8_t)(unsafe.Pointer(&minor)),
+ (*C.uint8_t)(unsafe.Pointer(&patch)))
return 3, 2, 15
}
olm.SetPickleKeyImpl = func(key []byte) {
pickleKey = key
}
+
+ olm.InitNewAccount = func() (olm.Account, error) {
+ return NewAccount()
+ }
+ olm.InitBlankAccount = func() olm.Account {
+ return NewBlankAccount()
+ }
+ olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) {
+ return AccountFromPickled(pickled, key)
+ }
+
+ olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) {
+ return SessionFromPickled(pickled, key)
+ }
+ olm.InitNewBlankSession = func() olm.Session {
+ return NewBlankSession()
+ }
+
+ olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() }
+ olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) {
+ return NewPKSigningFromSeed(seed)
+ }
+ olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) {
+ return NewPkDecryption(privateKey)
+ }
+
+ olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
+ return InboundGroupSessionFromPickled(pickled, key)
+ }
+ olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
+ return NewInboundGroupSession(sessionKey)
+ }
+ olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
+ return InboundGroupSessionImport(sessionKey)
+ }
+ olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession {
+ return NewBlankInboundGroupSession()
+ }
+
+ olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
+ if len(pickled) == 0 {
+ return nil, olm.ErrEmptyInput
+ }
+ s := NewBlankOutboundGroupSession()
+ return s, s.Unpickle(pickled, key)
+ }
+ olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() }
+ olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() }
}
diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go
index 4cc22809..1441df26 100644
--- a/crypto/libolm/session.go
+++ b/crypto/libolm/session.go
@@ -23,6 +23,7 @@ import "C"
import (
"crypto/rand"
"encoding/base64"
+ "runtime"
"unsafe"
"maunium.net/go/mautrix/crypto/olm"
@@ -38,15 +39,6 @@ type Session struct {
// Ensure that [Session] implements [olm.Session].
var _ olm.Session = (*Session)(nil)
-func init() {
- olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) {
- return SessionFromPickled(pickled, key)
- }
- olm.InitNewBlankSession = func() olm.Session {
- return NewBlankSession()
- }
-}
-
// sessionSize is the size of a session object in bytes.
func sessionSize() uint {
return uint(C.olm_session_size())
@@ -59,7 +51,7 @@ func sessionSize() uint {
// "INVALID_BASE64".
func SessionFromPickled(pickled, key []byte) (*Session, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
return s, s.Unpickle(pickled, key)
@@ -68,7 +60,7 @@ func SessionFromPickled(pickled, key []byte) (*Session, error) {
func NewBlankSession() *Session {
memory := make([]byte, sessionSize())
return &Session{
- int: C.olm_session(unsafe.Pointer(&memory[0])),
+ int: C.olm_session(unsafe.Pointer(unsafe.SliceData(memory))),
mem: memory,
}
}
@@ -126,13 +118,16 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint {
// will be "BAD_MESSAGE_FORMAT".
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
if len(message) == 0 {
- return 0, olm.EmptyInput
+ return 0, olm.ErrEmptyInput
}
+ messageCopy := []byte(message)
r := C.olm_decrypt_max_plaintext_length(
(*C.OlmSession)(s.int),
C.size_t(msgType),
- unsafe.Pointer(C.CString(message)),
- C.size_t(len(message)))
+ unsafe.Pointer(unsafe.SliceData((messageCopy))),
+ C.size_t(len(messageCopy)),
+ )
+ runtime.KeepAlive(messageCopy)
if r == errorVal() {
return 0, s.lastError()
}
@@ -143,15 +138,16 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType)
// supplied key.
func (s *Session) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_session(
(*C.OlmSession)(s.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
+ unsafe.Pointer(unsafe.SliceData(pickled)),
C.size_t(len(pickled)))
+ runtime.KeepAlive(key)
if r == errorVal() {
panic(s.lastError())
}
@@ -162,14 +158,16 @@ func (s *Session) Pickle(key []byte) ([]byte, error) {
// provided key. This function mutates the input pickled data slice.
func (s *Session) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_session(
(*C.OlmSession)(s.int),
- unsafe.Pointer(&key[0]),
+ unsafe.Pointer(unsafe.SliceData(key)),
C.size_t(len(key)),
- unsafe.Pointer(&pickled[0]),
+ unsafe.Pointer(unsafe.SliceData(pickled)),
C.size_t(len(pickled)))
+ runtime.KeepAlive(pickled)
+ runtime.KeepAlive(key)
if r == errorVal() {
return s.lastError()
}
@@ -215,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *Session) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankSession()
@@ -229,8 +227,9 @@ func (s *Session) ID() id.SessionID {
sessionID := make([]byte, s.idLen())
r := C.olm_session_id(
(*C.OlmSession)(s.int),
- unsafe.Pointer(&sessionID[0]),
- C.size_t(len(sessionID)))
+ unsafe.Pointer(unsafe.SliceData(sessionID)),
+ C.size_t(len(sessionID)),
+ )
if r == errorVal() {
panic(s.lastError())
}
@@ -257,12 +256,15 @@ func (s *Session) HasReceivedMessage() bool {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
if len(oneTimeKeyMsg) == 0 {
- return false, olm.EmptyInput
+ return false, olm.ErrEmptyInput
}
+ oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_matches_inbound_session(
(*C.OlmSession)(s.int),
- unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]),
- C.size_t(len(oneTimeKeyMsg)))
+ unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
+ C.size_t(len(oneTimeKeyMsgCopy)),
+ )
+ runtime.KeepAlive(oneTimeKeyMsgCopy)
if r == 1 {
return true, nil
} else if r == 0 {
@@ -282,14 +284,19 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
- return false, olm.EmptyInput
+ return false, olm.ErrEmptyInput
}
+ theirIdentityKeyCopy := []byte(theirIdentityKey)
+ oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_matches_inbound_session_from(
(*C.OlmSession)(s.int),
- unsafe.Pointer(&([]byte(theirIdentityKey))[0]),
- C.size_t(len(theirIdentityKey)),
- unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]),
- C.size_t(len(oneTimeKeyMsg)))
+ unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)),
+ C.size_t(len(theirIdentityKeyCopy)),
+ unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
+ C.size_t(len(oneTimeKeyMsgCopy)),
+ )
+ runtime.KeepAlive(theirIdentityKeyCopy)
+ runtime.KeepAlive(oneTimeKeyMsgCopy)
if r == 1 {
return true, nil
} else if r == 0 {
@@ -318,25 +325,28 @@ func (s *Session) EncryptMsgType() id.OlmMsgType {
// as base64.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
- return 0, nil, olm.EmptyInput
+ return 0, nil, olm.ErrEmptyInput
}
// Make the slice be at least length 1
random := make([]byte, s.encryptRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
// TODO can we just return err here?
- return 0, nil, olm.NotEnoughGoRandom
+ return 0, nil, olm.ErrNotEnoughGoRandom
}
messageType := s.EncryptMsgType()
message := make([]byte, s.encryptMsgLen(len(plaintext)))
r := C.olm_encrypt(
(*C.OlmSession)(s.int),
- unsafe.Pointer(&plaintext[0]),
+ unsafe.Pointer(unsafe.SliceData(plaintext)),
C.size_t(len(plaintext)),
- unsafe.Pointer(&random[0]),
+ unsafe.Pointer(unsafe.SliceData(random)),
C.size_t(len(random)),
- unsafe.Pointer(&message[0]),
- C.size_t(len(message)))
+ unsafe.Pointer(unsafe.SliceData(message)),
+ C.size_t(len(message)),
+ )
+ runtime.KeepAlive(plaintext)
+ runtime.KeepAlive(random)
if r == errorVal() {
return 0, nil, s.lastError()
}
@@ -352,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
// "BAD_MESSAGE_MAC".
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
if len(message) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType)
if err != nil {
@@ -363,10 +373,12 @@ func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error)
r := C.olm_decrypt(
(*C.OlmSession)(s.int),
C.size_t(msgType),
- unsafe.Pointer(&(messageCopy)[0]),
+ unsafe.Pointer(unsafe.SliceData(messageCopy)),
C.size_t(len(messageCopy)),
- unsafe.Pointer(&plaintext[0]),
- C.size_t(len(plaintext)))
+ unsafe.Pointer(unsafe.SliceData(plaintext)),
+ C.size_t(len(plaintext)),
+ )
+ runtime.KeepAlive(messageCopy)
if r == errorVal() {
return nil, s.lastError()
}
@@ -383,6 +395,7 @@ func (s *Session) Describe() string {
C.meowlm_session_describe(
(*C.OlmSession)(s.int),
desc,
- C.size_t(maxDescribeSize))
+ C.size_t(maxDescribeSize),
+ )
return C.GoString(desc)
}
diff --git a/crypto/machine.go b/crypto/machine.go
index cac91bf8..fa051f94 100644
--- a/crypto/machine.go
+++ b/crypto/machine.go
@@ -15,10 +15,12 @@ import (
"time"
"github.com/rs/zerolog"
+ "go.mau.fi/util/ptr"
"go.mau.fi/util/exzerolog"
"maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -33,9 +35,11 @@ type OlmMachine struct {
CryptoStore Store
StateStore StateStore
- BackgroundCtx context.Context
+ backgroundCtx context.Context
+ cancelBackgroundCtx context.CancelFunc
PlaintextMentions bool
+ MSC4392Relations bool
AllowEncryptedState bool
// Never ask the server for keys automatically as a side effect during Megolm decryption.
@@ -120,8 +124,6 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor
CryptoStore: cryptoStore,
StateStore: stateStore,
- BackgroundCtx: context.Background(),
-
SendKeysMinTrust: id.TrustStateUnset,
ShareKeysMinTrust: id.TrustStateCrossSignedTOFU,
@@ -134,6 +136,7 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor
recentlyUnwedged: make(map[id.IdentityKey]time.Time),
secretListeners: make(map[string]chan<- string),
}
+ mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(context.Background())
mach.AllowKeyShare = mach.defaultAllowKeyShare
return mach
}
@@ -146,6 +149,11 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
return log
}
+func (mach *OlmMachine) SetBackgroundCtx(ctx context.Context) {
+ mach.cancelBackgroundCtx()
+ mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(ctx)
+}
+
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
// This must be called before using the machine.
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
@@ -156,9 +164,23 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) {
if mach.account == nil {
mach.account = NewOlmAccount()
}
+ zerolog.Ctx(ctx).Debug().
+ Str("machine_ptr", fmt.Sprintf("%p", mach)).
+ Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)).
+ Str("olm_driver", olm.Driver).
+ Msg("Loaded olm account")
return nil
}
+func (mach *OlmMachine) Destroy() {
+ mach.Log.Debug().
+ Str("machine_ptr", fmt.Sprintf("%p", mach)).
+ Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)).
+ Msg("Destroying olm machine")
+ mach.cancelBackgroundCtx()
+ // TODO actually destroy something?
+}
+
func (mach *OlmMachine) saveAccount(ctx context.Context) error {
err := mach.CryptoStore.PutAccount(ctx, mach.account)
if err != nil {
@@ -184,7 +206,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error {
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
start := time.Now()
return func() {
- duration := time.Now().Sub(start)
+ duration := time.Since(start)
if duration > expectedDuration {
zerolog.Ctx(ctx).Warn().
Str("action", thing).
@@ -361,7 +383,7 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event)
Msg("Got membership state change, invalidating group session in room")
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
if err != nil {
- mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
+ mach.Log.Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session")
}
}
@@ -581,7 +603,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
- log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
+ log.Err(err).Stringer("session_id", sessionID).Msg("Failed to store new inbound group session")
return fmt.Errorf("failed to store new inbound group session: %w", err)
}
mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex())
@@ -708,7 +730,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
start := time.Now()
mach.otkUploadLock.Lock()
defer mach.otkUploadLock.Unlock()
- if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 {
+ if mach.lastOTKUpload.Add(1*time.Minute).After(start) || (currentOTKCount < 0 && mach.account.Shared) {
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count")
resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{})
if err != nil {
diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go
new file mode 100644
index 00000000..fd40d795
--- /dev/null
+++ b/crypto/machine_bench_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package crypto_test
+
+import (
+ "context"
+ "fmt"
+ "math/rand/v2"
+ "testing"
+
+ "github.com/rs/zerolog"
+ globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/crypto/cryptohelper"
+ "maunium.net/go/mautrix/id"
+ "maunium.net/go/mautrix/mockserver"
+)
+
+func randomDeviceCount(r *rand.Rand) int {
+ k := 1
+ for k < 10 && r.IntN(3) > 0 {
+ k++
+ }
+ return k
+}
+
+func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) {
+ globallog.Logger = zerolog.Nop()
+ server := mockserver.Create(b)
+ server.PopOTKs = false
+ server.MemoryStore = false
+ var i int
+ var shareTargets []id.UserID
+ r := rand.New(rand.NewPCG(293, 0))
+ var totalDeviceCount int
+ for i = 1; i < 1000; i++ {
+ userID := id.UserID(fmt.Sprintf("@user%d:localhost", i))
+ deviceCount := randomDeviceCount(r)
+ for j := 0; j < deviceCount; j++ {
+ client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j)))
+ mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
+ keysCache, err := mach.GenerateCrossSigningKeys()
+ require.NoError(b, err)
+ err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
+ require.NoError(b, err)
+ }
+ totalDeviceCount += deviceCount
+ shareTargets = append(shareTargets, userID)
+ }
+ for b.Loop() {
+ client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i)))
+ mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
+ keysCache, err := mach.GenerateCrossSigningKeys()
+ require.NoError(b, err)
+ err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
+ require.NoError(b, err)
+ err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets)
+ require.NoError(b, err)
+ i++
+ }
+ fmt.Println(totalDeviceCount, "devices total")
+}
diff --git a/crypto/machine_test.go b/crypto/machine_test.go
index 59c86236..872c3ac4 100644
--- a/crypto/machine_test.go
+++ b/crypto/machine_test.go
@@ -36,20 +36,15 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID,
func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
client, err := mautrix.NewClient("http://localhost", userID, "token")
- if err != nil {
- t.Fatalf("Error creating client: %v", err)
- }
+ require.NoError(t, err, "Error creating client")
client.DeviceID = "device1"
gobStore := NewMemoryStore(nil)
- if err != nil {
- t.Fatalf("Error creating Gob store: %v", err)
- }
+ require.NoError(t, err, "Error creating Gob store")
machine := NewOlmMachine(client, nil, gobStore, mockStateStore{})
- if err := machine.Load(context.TODO()); err != nil {
- t.Fatalf("Error creating account: %v", err)
- }
+ err = machine.Load(context.TODO())
+ require.NoError(t, err, "Error creating account")
return machine
}
@@ -82,9 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
// create outbound olm session for sending machine using OTK
olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key)
- if err != nil {
- t.Errorf("Failed to create outbound olm session: %v", err)
- }
+ require.NoError(t, err, "Error creating outbound olm session")
// store sender device identity in receiving machine store
machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{
@@ -121,29 +114,21 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
Type: event.ToDeviceEncrypted,
Sender: "user1",
}, senderKey, content.Type, content.Body)
- if err != nil {
- t.Errorf("Error decrypting olm content: %v", err)
- }
+ require.NoError(t, err, "Error decrypting olm ciphertext")
+
// store room key in new inbound group session
roomKeyEvt := decrypted.Content.AsRoomKey()
igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false)
- if err != nil {
- t.Errorf("Error creating inbound megolm session: %v", err)
- }
- if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil {
- t.Errorf("Error storing inbound megolm session: %v", err)
- }
+ require.NoError(t, err, "Error creating inbound group session")
+ err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs)
+ require.NoError(t, err, "Error storing inbound group session")
}
// encrypt event with megolm session in sending machine
eventContent := map[string]string{"hello": "world"}
encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
- if err != nil {
- t.Errorf("Error encrypting megolm event: %v", err)
- }
- if megolmOutSession.MessageCount != 1 {
- t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount)
- }
+ require.NoError(t, err, "Error encrypting megolm event")
+ assert.Equal(t, 1, megolmOutSession.MessageCount)
encryptedEvt := &event.Event{
Content: event.Content{Parsed: encryptedEvtContent},
@@ -155,22 +140,12 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
// decrypt event on receiving machine and confirm
decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt)
- if err != nil {
- t.Errorf("Error decrypting megolm event: %v", err)
- }
- if decryptedEvt.Type != event.EventMessage {
- t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type)
- }
- if decryptedEvt.Content.Raw["hello"] != "world" {
- t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw)
- }
+ require.NoError(t, err, "Error decrypting megolm event")
+ assert.Equal(t, event.EventMessage, decryptedEvt.Type)
+ assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"])
machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
- if megolmOutSession.Expired() {
- t.Error("Megolm outbound session expired before 3rd message")
- }
+ assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message")
machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
- if !megolmOutSession.Expired() {
- t.Error("Megolm outbound session not expired after 3rd message")
- }
+ assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message")
}
diff --git a/crypto/olm/account.go b/crypto/olm/account.go
index 68393e8a..2ec5dd70 100644
--- a/crypto/olm/account.go
+++ b/crypto/olm/account.go
@@ -87,6 +87,8 @@ type Account interface {
RemoveOneTimeKeys(s Session) error
}
+var Driver = "none"
+
var InitBlankAccount func() Account
var InitNewAccount func() (Account, error)
var InitNewAccountFromPickled func(pickled, key []byte) (Account, error)
diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go
index 957d7928..9e522b2a 100644
--- a/crypto/olm/errors.go
+++ b/crypto/olm/errors.go
@@ -10,50 +10,67 @@ import "errors"
// Those are the most common used errors
var (
- ErrBadSignature = errors.New("bad signature")
- ErrBadMAC = errors.New("bad mac")
- ErrBadMessageFormat = errors.New("bad message format")
- ErrBadVerification = errors.New("bad verification")
- ErrWrongProtocolVersion = errors.New("wrong protocol version")
- ErrEmptyInput = errors.New("empty input")
- ErrNoKeyProvided = errors.New("no key")
- ErrBadMessageKeyID = errors.New("bad message key id")
- ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
- ErrMsgIndexTooHigh = errors.New("message index too high")
- ErrProtocolViolation = errors.New("not protocol message order")
- ErrMessageKeyNotFound = errors.New("message key not found")
- ErrChainTooHigh = errors.New("chain index too high")
- ErrBadInput = errors.New("bad input")
- ErrBadVersion = errors.New("wrong version")
- ErrWrongPickleVersion = errors.New("wrong pickle version")
- ErrInputToSmall = errors.New("input too small (truncated?)")
- ErrOverflow = errors.New("overflow")
+ ErrBadSignature = errors.New("bad signature")
+ ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)")
+ ErrBadMessageFormat = errors.New("the message couldn't be decoded")
+ ErrBadVerification = errors.New("bad verification")
+ ErrWrongProtocolVersion = errors.New("wrong protocol version")
+ ErrEmptyInput = errors.New("empty input")
+ ErrNoKeyProvided = errors.New("no key provided")
+ ErrBadMessageKeyID = errors.New("the message references an unknown key ID")
+ ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
+ ErrMsgIndexTooHigh = errors.New("message index too high")
+ ErrProtocolViolation = errors.New("not protocol message order")
+ ErrMessageKeyNotFound = errors.New("message key not found")
+ ErrChainTooHigh = errors.New("chain index too high")
+ ErrBadInput = errors.New("bad input")
+ ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version")
+ ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version")
+ ErrInputToSmall = errors.New("input too small (truncated?)")
)
// Error codes from go-olm
var (
- EmptyInput = errors.New("empty input")
- NoKeyProvided = errors.New("no pickle key provided")
- NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
- SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
- InputNotJSONString = errors.New("input doesn't look like a JSON string")
+ ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
+ ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
)
// Error codes from olm code
var (
- NotEnoughRandom = errors.New("not enough entropy was supplied")
- OutputBufferTooSmall = errors.New("supplied output buffer is too small")
- BadMessageVersion = errors.New("the message version is unsupported")
- BadMessageFormat = errors.New("the message couldn't be decoded")
- BadMessageMAC = errors.New("the message couldn't be decrypted")
- BadMessageKeyID = errors.New("the message references an unknown key ID")
- InvalidBase64 = errors.New("the input base64 was invalid")
- BadAccountKey = errors.New("the supplied account key is invalid")
- UnknownPickleVersion = errors.New("the pickled object is too new")
- CorruptedPickle = errors.New("the pickled object couldn't be decoded")
- BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
- UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
- BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
- BadSignature = errors.New("received message had a bad signature")
- InputBufferTooSmall = errors.New("the input data was too small to be valid")
+ ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid")
+
+ ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied")
+ ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small")
+ ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid")
+ ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded")
+ ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
+ ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ EmptyInput = ErrEmptyInput
+ BadSignature = ErrBadSignature
+ InvalidBase64 = ErrLibolmInvalidBase64
+ BadMessageKeyID = ErrBadMessageKeyID
+ BadMessageFormat = ErrBadMessageFormat
+ BadMessageVersion = ErrWrongProtocolVersion
+ BadMessageMAC = ErrBadMAC
+ UnknownPickleVersion = ErrUnknownOlmPickleVersion
+ NotEnoughRandom = ErrLibolmNotEnoughRandom
+ OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall
+ BadAccountKey = ErrLibolmBadAccountKey
+ CorruptedPickle = ErrLibolmCorruptedPickle
+ BadSessionKey = ErrLibolmBadSessionKey
+ UnknownMessageIndex = ErrUnknownMessageIndex
+ BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle
+ InputBufferTooSmall = ErrInputToSmall
+ NoKeyProvided = ErrNoKeyProvided
+
+ NotEnoughGoRandom = ErrNotEnoughGoRandom
+ InputNotJSONString = ErrInputNotJSONString
+
+ ErrBadVersion = ErrUnknownJSONPickleVersion
+ ErrWrongPickleVersion = ErrUnknownJSONPickleVersion
+ ErrRatchetNotAvailable = ErrUnknownMessageIndex
)
diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go
index f5cecafc..6b5b65fd 100644
--- a/crypto/registergoolm.go
+++ b/crypto/registergoolm.go
@@ -2,4 +2,10 @@
package crypto
-import _ "maunium.net/go/mautrix/crypto/goolm"
+import (
+ "maunium.net/go/mautrix/crypto/goolm"
+)
+
+func init() {
+ goolm.Register()
+}
diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go
index ab388a5c..ef78b6b5 100644
--- a/crypto/registerlibolm.go
+++ b/crypto/registerlibolm.go
@@ -2,4 +2,8 @@
package crypto
-import _ "maunium.net/go/mautrix/crypto/libolm"
+import "maunium.net/go/mautrix/crypto/libolm"
+
+func init() {
+ libolm.Register()
+}
diff --git a/crypto/sessions.go b/crypto/sessions.go
index aecb0416..ccc7b784 100644
--- a/crypto/sessions.go
+++ b/crypto/sessions.go
@@ -18,8 +18,14 @@ import (
)
var (
- SessionNotShared = errors.New("session has not been shared")
- SessionExpired = errors.New("session has expired")
+ ErrSessionNotShared = errors.New("session has not been shared")
+ ErrSessionExpired = errors.New("session has expired")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ SessionNotShared = ErrSessionNotShared
+ SessionExpired = ErrSessionExpired
)
// OlmSessionList is a list of OlmSessions.
@@ -111,6 +117,7 @@ type InboundGroupSession struct {
MaxMessages int
IsScheduled bool
KeyBackupVersion id.KeyBackupVersion
+ KeySource id.KeySource
id id.SessionID
}
@@ -130,6 +137,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: isScheduled,
+ KeySource: id.KeySourceDirect,
}, nil
}
@@ -163,7 +171,7 @@ func (igs *InboundGroupSession) export() (*ExportedSession, error) {
ForwardingChains: igs.ForwardingChains,
RoomID: igs.RoomID,
SenderKey: igs.SenderKey,
- SenderClaimedKeys: SenderClaimedKeys{},
+ SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey},
SessionID: igs.ID(),
SessionKey: string(key),
}, nil
@@ -255,9 +263,9 @@ func (ogs *OutboundGroupSession) Expired() bool {
func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if !ogs.Shared {
- return nil, SessionNotShared
+ return nil, ErrSessionNotShared
} else if ogs.Expired() {
- return nil, SessionExpired
+ return nil, ErrSessionExpired
}
ogs.MessageCount++
ogs.LastEncryptedTime = time.Now()
diff --git a/crypto/sql_store.go b/crypto/sql_store.go
index b0625763..138cc557 100644
--- a/crypto/sql_store.go
+++ b/crypto/sql_store.go
@@ -251,8 +251,9 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender
}
// GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key.
+// This will exclude sessions that have never been used to encrypt or decrypt a message.
func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) {
- err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY created_at DESC LIMIT 1",
+ err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (last_encrypted <> created_at OR last_decrypted <> created_at) ORDER BY created_at DESC LIMIT 1",
key, store.AccountID).Scan(&createdAt)
if errors.Is(err, sql.ErrNoRows) {
err = nil
@@ -345,22 +346,23 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou
Int("max_messages", session.MaxMessages).
Bool("is_scheduled", session.IsScheduled).
Stringer("key_backup_version", session.KeyBackupVersion).
+ Stringer("key_source", session.KeySource).
Msg("Upserting megolm inbound group session")
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
- ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id
- ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+ ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
- key_backup_version=excluded.key_backup_version
+ key_backup_version=excluded.key_backup_version, key_source=excluded.key_source
`,
session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages),
- session.IsScheduled, session.KeyBackupVersion, store.AccountID,
+ session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID,
)
return err
}
@@ -373,12 +375,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
var version id.KeyBackupVersion
+ var keySource id.KeySource
err := store.DB.QueryRow(ctx, `
- SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3`,
roomID, sessionID, store.AccountID,
- ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
+ ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
@@ -409,6 +412,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
+ KeySource: keySource,
}, nil
}
@@ -533,7 +537,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
var version id.KeyBackupVersion
- err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
+ var keySource id.KeySource
+ err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
if err != nil {
return nil, err
}
@@ -553,12 +558,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
+ KeySource: keySource,
}, nil
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
@@ -567,7 +573,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
store.AccountID,
)
@@ -576,7 +582,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
store.AccountID, version,
)
@@ -663,6 +669,20 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
+ if eventID == "" && timestamp == 0 {
+ var notOK bool
+ const validateEmptyQuery = `
+ SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3)
+ `
+ err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK)
+ if notOK {
+ zerolog.Ctx(ctx).Debug().
+ Uint("message_index", index).
+ Msg("Rejecting event without event ID and timestamp due to already knowing them")
+ }
+ return !notOK, err
+ }
+
const validateQuery = `
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
VALUES ($1, $2, $3, $4, $5)
diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql
index 00dd1387..3709f1e5 100644
--- a/crypto/sql_store_upgrade/00-latest-revision.sql
+++ b/crypto/sql_store_upgrade/00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v17 (compatible with v15+): Latest revision
+-- v0 -> v19 (compatible with v15+): Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@@ -71,8 +71,11 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
key_backup_version TEXT NOT NULL DEFAULT '',
+ key_source TEXT NOT NULL DEFAULT '',
PRIMARY KEY (account_id, session_id)
);
+-- Useful index to find keys that need backing up
+CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL;
CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
account_id TEXT,
diff --git a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql
new file mode 100644
index 00000000..da26da0f
--- /dev/null
+++ b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql
@@ -0,0 +1,2 @@
+-- v18 (compatible with v15+): Add an index to the megolm_inbound_session table to make finding sessions to backup faster
+CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL;
diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql
new file mode 100644
index 00000000..f624222f
--- /dev/null
+++ b/crypto/sql_store_upgrade/19-megolm-session-source.sql
@@ -0,0 +1,2 @@
+-- v19 (compatible with v15+): Store megolm session source
+ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT '';
diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go
index e30925d9..8691d032 100644
--- a/crypto/ssss/client.go
+++ b/crypto/ssss/client.go
@@ -95,6 +95,22 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even
return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted})
}
+// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it,
+// alongside the unencrypted metadata, on the server.
+func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error {
+ if len(keys) == 0 {
+ return ErrNoKeyGiven
+ }
+ encrypted := make(map[string]EncryptedKeyData, len(keys))
+ for _, key := range keys {
+ encrypted[key.ID] = key.Encrypt(eventType.Type, data)
+ }
+ return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{
+ Encrypted: encrypted,
+ Metadata: metadata,
+ })
+}
+
// GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server.
func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) {
key, err = NewKey(passphrase)
diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go
index aa22360a..78ebd8f3 100644
--- a/crypto/ssss/key.go
+++ b/crypto/ssss/key.go
@@ -7,6 +7,8 @@
package ssss
import (
+ "crypto/hmac"
+ "crypto/sha256"
"encoding/base64"
"fmt"
"strings"
@@ -57,12 +59,12 @@ func NewKey(passphrase string) (*Key, error) {
// We store a certain hash in the key metadata so that clients can check if the user entered the correct key.
ivBytes := random.Bytes(utils.AESCTRIVLength)
keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes)
- var err error
- keyData.MAC, err = keyData.calculateHash(ssssKey)
+ macBytes, err := keyData.calculateHash(ssssKey)
if err != nil {
// This should never happen because we just generated the IV and key.
return nil, fmt.Errorf("failed to calculate hash: %w", err)
}
+ keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes)
return &Key{
Key: ssssKey,
@@ -108,12 +110,18 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error)
return nil, err
}
+ mac, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.MAC, "="))
+ if err != nil {
+ return nil, err
+ }
+
// derive the AES and HMAC keys for the requested event type using the SSSS key
aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType)
// compare the stored MAC with the one we calculated from the ciphertext
- calcMac := utils.HMACSHA256B64(payload, hmacKey)
- if strings.TrimRight(data.MAC, "=") != calcMac {
+ h := hmac.New(sha256.New, hmacKey[:])
+ h.Write(payload)
+ if !hmac.Equal(h.Sum(nil), mac) {
return nil, ErrKeyDataMACMismatch
}
diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go
index 474c85d8..34775fa7 100644
--- a/crypto/ssss/meta.go
+++ b/crypto/ssss/meta.go
@@ -7,7 +7,10 @@
package ssss
import (
+ "crypto/hmac"
+ "crypto/sha256"
"encoding/base64"
+ "errors"
"fmt"
"strings"
@@ -33,7 +36,9 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error)
ssssKey, err := kd.Passphrase.GetKey(passphrase)
if err != nil {
return nil, err
- } else if err = kd.verifyKey(ssssKey); err != nil {
+ }
+ err = kd.verifyKey(ssssKey)
+ if err != nil && !errors.Is(err, ErrUnverifiableKey) {
return nil, err
}
@@ -49,7 +54,9 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error
ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey)
if ssssKey == nil {
return nil, ErrInvalidRecoveryKey
- } else if err := kd.verifyKey(ssssKey); err != nil {
+ }
+ err := kd.verifyKey(ssssKey)
+ if err != nil && !errors.Is(err, ErrUnverifiableKey) {
return nil, err
}
@@ -57,20 +64,28 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error
ID: keyID,
Key: ssssKey,
Metadata: kd,
- }, nil
+ }, err
}
func (kd *KeyMetadata) verifyKey(key []byte) error {
+ if kd.MAC == "" || kd.IV == "" {
+ return ErrUnverifiableKey
+ }
unpaddedMAC := strings.TrimRight(kd.MAC, "=")
expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength)
if len(unpaddedMAC) != expectedMACLength {
return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength)
}
- hash, err := kd.calculateHash(key)
+ expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC)
+ if err != nil {
+ return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err)
+ }
+ calculatedMAC, err := kd.calculateHash(key)
if err != nil {
return err
}
- if unpaddedMAC != hash {
+ // This doesn't really need to be constant time since it's fully local, but might as well be.
+ if !hmac.Equal(expectedMAC, calculatedMAC) {
return ErrIncorrectSSSSKey
}
return nil
@@ -83,23 +98,26 @@ func (kd *KeyMetadata) VerifyKey(key []byte) bool {
// calculateHash calculates the hash used for checking if the key is entered correctly as described
// in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2
-func (kd *KeyMetadata) calculateHash(key []byte) (string, error) {
+func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) {
aesKey, hmacKey := utils.DeriveKeysSHA256(key, "")
unpaddedIV := strings.TrimRight(kd.IV, "=")
expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength)
- if len(unpaddedIV) != expectedIVLength {
- return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength)
+ if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 {
+ return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength)
}
-
- var ivBytes [utils.AESCTRIVLength]byte
- _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV))
+ rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV)
if err != nil {
- return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err)
+ return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err)
}
+ // TODO log a warning for non-16 byte IVs?
+ // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used.
+ ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength])
- cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes)
-
- return utils.HMACSHA256B64(cipher, hmacKey), nil
+ zeroes := make([]byte, utils.AESCTRKeyLength)
+ encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes)
+ h := hmac.New(sha256.New, hmacKey[:])
+ h.Write(encryptedZeroes)
+ return h.Sum(nil), nil
}
// PassphraseMetadata represents server-side metadata about a SSSS key passphrase.
diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go
index 4f2ff378..d59809c7 100644
--- a/crypto/ssss/meta_test.go
+++ b/crypto/ssss/meta_test.go
@@ -8,10 +8,10 @@ package ssss_test
import (
"encoding/json"
- "errors"
"testing"
"github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/crypto/ssss"
)
@@ -41,10 +41,24 @@ const key2Meta = `
}
`
+const key2MetaUnverified = `
+{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2"
+}
+`
+
+const key2MetaLongIV = `
+{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
+}
+`
+
const key2MetaBrokenIV = `
{
"algorithm": "m.secret_storage.v1.aes-hmac-sha2",
- "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow",
+ "iv": "MeowMeowMeow",
"mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
}
`
@@ -70,23 +84,11 @@ func getKeyMeta(meta string) *ssss.KeyMetadata {
}
func getKey1() *ssss.Key {
- km := getKeyMeta(key1Meta)
- key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey)
- if err != nil {
- panic(err)
- }
- key.ID = key1ID
- return key
+ return exerrors.Must(getKeyMeta(key1Meta).VerifyRecoveryKey(key1ID, key1RecoveryKey))
}
func getKey2() *ssss.Key {
- km := getKeyMeta(key2Meta)
- key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- if err != nil {
- panic(err)
- }
- key.ID = key2ID
- return key
+ return exerrors.Must(getKeyMeta(key2Meta).VerifyRecoveryKey(key2ID, key2RecoveryKey))
}
func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) {
@@ -105,17 +107,33 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) {
assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
}
+func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) {
+ km := getKeyMeta(key2MetaLongIV)
+ key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
+ assert.NoError(t, err)
+ assert.NotNil(t, key)
+ assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
+}
+
+func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) {
+ km := getKeyMeta(key2MetaUnverified)
+ key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
+ assert.ErrorIs(t, err, ssss.ErrUnverifiableKey)
+ assert.NotNil(t, key)
+ assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
+}
+
func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyRecoveryKey(key1ID, "foo")
- assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err)
+ assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err)
+ assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey)
assert.Nil(t, key)
}
@@ -130,27 +148,27 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) {
func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple")
- assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) {
km := getKeyMeta(key2Meta)
key, err := km.VerifyPassphrase(key2ID, "hmm")
- assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrNoPassphrase)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) {
km := getKeyMeta(key2MetaBrokenIV)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) {
km := getKeyMeta(key2MetaBrokenMAC)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata)
assert.Nil(t, key)
}
diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go
index 345393b0..b7465d3e 100644
--- a/crypto/ssss/types.go
+++ b/crypto/ssss/types.go
@@ -26,7 +26,8 @@ var (
ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm")
ErrIncorrectSSSSKey = errors.New("incorrect SSSS key")
ErrInvalidRecoveryKey = errors.New("invalid recovery key")
- ErrCorruptedKeyMetadata = errors.New("corrupted key metadata")
+ ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata")
+ ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata")
)
// Algorithm is the identifier for an SSSS encryption algorithm.
@@ -57,6 +58,7 @@ type EncryptedKeyData struct {
type EncryptedAccountDataEventContent struct {
Encrypted map[string]EncryptedKeyData `json:"encrypted"`
+ Metadata map[string]any `json:"com.beeper.metadata,omitzero"`
}
func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) {
diff --git a/crypto/store.go b/crypto/store.go
index 8b7c0a96..7620cf35 100644
--- a/crypto/store.go
+++ b/crypto/store.go
@@ -525,6 +525,9 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send
}
val, ok := gs.MessageIndices[key]
if !ok {
+ if eventID == "" && timestamp == 0 {
+ return true, nil
+ }
gs.MessageIndices[key] = messageIndexValue{
EventID: eventID,
Timestamp: timestamp,
diff --git a/crypto/store_test.go b/crypto/store_test.go
index a7c4d75a..7a47243e 100644
--- a/crypto/store_test.go
+++ b/crypto/store_test.go
@@ -13,6 +13,7 @@ import (
"testing"
_ "github.com/mattn/go-sqlite3"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mau.fi/util/dbutil"
@@ -29,22 +30,14 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4
func getCryptoStores(t *testing.T) map[string]Store {
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
- if err != nil {
- t.Fatalf("Error opening db: %v", err)
- }
+ require.NoError(t, err, "Error opening raw database")
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
- if err != nil {
- t.Fatalf("Error opening db: %v", err)
- }
+ require.NoError(t, err, "Error creating database wrapper")
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
- if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
- t.Fatalf("Error creating tables: %v", err)
- }
+ err = sqlStore.DB.Upgrade(context.TODO())
+ require.NoError(t, err, "Error upgrading database")
gobStore := NewMemoryStore(nil)
- if err != nil {
- t.Fatalf("Error creating Gob store: %v", err)
- }
return map[string]Store{
"sql": sqlStore,
@@ -56,9 +49,10 @@ func TestPutNextBatch(t *testing.T) {
stores := getCryptoStores(t)
store := stores["sql"].(*SQLCryptoStore)
store.PutNextBatch(context.Background(), "batch1")
- if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" {
- t.Errorf("Expected batch1, got %v", batch)
- }
+
+ batch, err := store.GetNextBatch(context.Background())
+ require.NoError(t, err, "Error retrieving next batch")
+ assert.Equal(t, "batch1", batch)
}
func TestPutAccount(t *testing.T) {
@@ -68,15 +62,9 @@ func TestPutAccount(t *testing.T) {
acc := NewOlmAccount()
store.PutAccount(context.TODO(), acc)
retrieved, err := store.GetAccount(context.TODO())
- if err != nil {
- t.Fatalf("Error retrieving account: %v", err)
- }
- if acc.IdentityKey() != retrieved.IdentityKey() {
- t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey())
- }
- if acc.SigningKey() != retrieved.SigningKey() {
- t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey())
- }
+ require.NoError(t, err, "Error retrieving account")
+ assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match")
+ assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match")
})
}
}
@@ -86,18 +74,36 @@ func TestValidateMessageIndex(t *testing.T) {
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
acc := NewOlmAccount()
- if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok {
- t.Error("First message not validated successfully")
- }
- if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok {
- t.Error("First message validated successfully after changing timestamp")
- }
- if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok {
- t.Error("First message validated successfully after changing event ID")
- }
- if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok {
- t.Error("First message not validated successfully for a second time")
- }
+
+ // Validating without event ID and timestamp before we have them should work
+ ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0)
+ require.NoError(t, err, "Error validating message index")
+ assert.True(t, ok, "First message validation should be valid")
+
+ // First message should validate successfully
+ ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000)
+ require.NoError(t, err, "Error validating message index")
+ assert.True(t, ok, "First message validation should be valid")
+
+ // Edit the timestamp and ensure validate fails
+ ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001)
+ require.NoError(t, err, "Error validating message index after timestamp change")
+ assert.False(t, ok, "First message validation should fail after timestamp change")
+
+ // Edit the event ID and ensure validate fails
+ ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000)
+ require.NoError(t, err, "Error validating message index after event ID change")
+ assert.False(t, ok, "First message validation should fail after event ID change")
+
+ // Validate again with the original parameters and ensure that it still passes
+ ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000)
+ require.NoError(t, err, "Error validating message index")
+ assert.True(t, ok, "First message validation should be valid")
+
+ // Validating without event ID and timestamp must fail if we already know them
+ ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0)
+ require.NoError(t, err, "Error validating message index")
+ assert.False(t, ok, "First message validation should be invalid")
})
}
}
@@ -106,43 +112,26 @@ func TestStoreOlmSession(t *testing.T) {
stores := getCryptoStores(t)
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
- if store.HasSession(context.TODO(), olmSessID) {
- t.Error("Found Olm session before inserting it")
- }
+ require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it")
+
olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test"))
- if err != nil {
- t.Fatalf("Error creating internal Olm session: %v", err)
- }
+ require.NoError(t, err, "Error creating internal Olm session")
olmSess := OlmSession{
id: olmSessID,
Internal: olmInternal,
}
err = store.AddSession(context.TODO(), olmSessID, &olmSess)
- if err != nil {
- t.Errorf("Error storing Olm session: %v", err)
- }
- if !store.HasSession(context.TODO(), olmSessID) {
- t.Error("Not found Olm session after inserting it")
- }
+ require.NoError(t, err, "Error storing Olm session")
+ assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it")
retrieved, err := store.GetLatestSession(context.TODO(), olmSessID)
- if err != nil {
- t.Errorf("Failed retrieving Olm session: %v", err)
- }
-
- if retrieved.ID() != olmSessID {
- t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID())
- }
+ require.NoError(t, err, "Error retrieving Olm session")
+ assert.EqualValues(t, olmSessID, retrieved.ID())
pickled, err := retrieved.Internal.Pickle([]byte("test"))
- if err != nil {
- t.Fatalf("Error pickling Olm session: %v", err)
- }
-
- if string(pickled) != olmPickled {
- t.Error("Pickled Olm session does not match original")
- }
+ require.NoError(t, err, "Error pickling Olm session")
+ assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original")
})
}
}
@@ -154,9 +143,7 @@ func TestStoreMegolmSession(t *testing.T) {
acc := NewOlmAccount()
internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test"))
- if err != nil {
- t.Fatalf("Error creating internal inbound group session: %v", err)
- }
+ require.NoError(t, err, "Error creating internal inbound group session")
igs := &InboundGroupSession{
Internal: internal,
@@ -166,20 +153,14 @@ func TestStoreMegolmSession(t *testing.T) {
}
err = store.PutGroupSession(context.TODO(), igs)
- if err != nil {
- t.Errorf("Error storing inbound group session: %v", err)
- }
+ require.NoError(t, err, "Error storing inbound group session")
retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID())
- if err != nil {
- t.Errorf("Error retrieving inbound group session: %v", err)
- }
+ require.NoError(t, err, "Error retrieving inbound group session")
- if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil {
- t.Fatalf("Error pickling inbound group session: %v", err)
- } else if string(pickled) != groupSession {
- t.Error("Pickled inbound group session does not match original")
- }
+ pickled, err := retrieved.Internal.Pickle([]byte("test"))
+ require.NoError(t, err, "Error pickling inbound group session")
+ assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original")
})
}
}
@@ -189,40 +170,24 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
sess, err := store.GetOutboundGroupSession(context.TODO(), "room1")
- if sess != nil {
- t.Error("Got outbound session before inserting")
- }
- if err != nil {
- t.Errorf("Error retrieving outbound session: %v", err)
- }
+ require.NoError(t, err, "Error retrieving outbound session")
+ require.Nil(t, sess, "Got outbound session before inserting")
outbound, err := NewOutboundGroupSession("room1", nil)
require.NoError(t, err)
err = store.AddOutboundGroupSession(context.TODO(), outbound)
- if err != nil {
- t.Errorf("Error inserting outbound session: %v", err)
- }
+ require.NoError(t, err, "Error inserting outbound session")
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
- if sess == nil {
- t.Error("Did not get outbound session after inserting")
- }
- if err != nil {
- t.Errorf("Error retrieving outbound session: %v", err)
- }
+ require.NoError(t, err, "Error retrieving outbound session")
+ assert.NotNil(t, sess, "Did not get outbound session after inserting")
err = store.RemoveOutboundGroupSession(context.TODO(), "room1")
- if err != nil {
- t.Errorf("Error deleting outbound session: %v", err)
- }
+ require.NoError(t, err, "Error deleting outbound session")
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
- if sess != nil {
- t.Error("Got outbound session after deleting")
- }
- if err != nil {
- t.Errorf("Error retrieving outbound session: %v", err)
- }
+ require.NoError(t, err, "Error retrieving outbound session after deletion")
+ assert.Nil(t, sess, "Got outbound session after deleting")
})
}
}
@@ -244,58 +209,41 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) {
t.Run(storeName, func(t *testing.T) {
device := resetDevice()
err := store.PutDevice(context.TODO(), "user1", device)
- if err != nil {
- t.Errorf("Error storing devices: %v", err)
- }
+ require.NoError(t, err, "Error storing device")
shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- if err != nil {
- t.Errorf("Error checking if outbound group session is shared: %v", err)
- } else if shared {
- t.Errorf("Outbound group session shared when it shouldn't")
- }
+ require.NoError(t, err, "Error checking if outbound group session is shared")
+ assert.False(t, shared, "Outbound group session should not be shared initially")
err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- if err != nil {
- t.Errorf("Error marking outbound group session as shared: %v", err)
- }
+ require.NoError(t, err, "Error marking outbound group session as shared")
shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- if err != nil {
- t.Errorf("Error checking if outbound group session is shared: %v", err)
- } else if !shared {
- t.Errorf("Outbound group session not shared when it should")
- }
+ require.NoError(t, err, "Error checking if outbound group session is shared")
+ assert.True(t, shared, "Outbound group session should be shared after marking it as such")
device = resetDevice()
err = store.PutDevice(context.TODO(), "user1", device)
- if err != nil {
- t.Errorf("Error storing devices: %v", err)
- }
+ require.NoError(t, err, "Error storing device after resetting")
shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- if err != nil {
- t.Errorf("Error checking if outbound group session is shared: %v", err)
- } else if shared {
- t.Errorf("Outbound group session shared when it shouldn't")
- }
+ require.NoError(t, err, "Error checking if outbound group session is shared")
+ assert.False(t, shared, "Outbound group session should not be shared after resetting device")
})
}
}
func TestStoreDevices(t *testing.T) {
+ devicesToCreate := 17
stores := getCryptoStores(t)
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
outdated, err := store.GetOutdatedTrackedUsers(context.TODO())
- if err != nil {
- t.Errorf("Error filtering tracked users: %v", err)
- }
- if len(outdated) > 0 {
- t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
- }
+ require.NoError(t, err, "Error filtering tracked users")
+ assert.Empty(t, outdated, "Expected no outdated tracked users initially")
+
deviceMap := make(map[id.DeviceID]*id.Device)
- for i := 0; i < 17; i++ {
+ for i := 0; i < devicesToCreate; i++ {
iStr := strconv.Itoa(i)
acc := NewOlmAccount()
deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{
@@ -306,59 +254,33 @@ func TestStoreDevices(t *testing.T) {
}
}
err = store.PutDevices(context.TODO(), "user1", deviceMap)
- if err != nil {
- t.Errorf("Error storing devices: %v", err)
- }
+ require.NoError(t, err, "Error storing devices")
devs, err := store.GetDevices(context.TODO(), "user1")
- if err != nil {
- t.Errorf("Error getting devices: %v", err)
- }
- if len(devs) != 17 {
- t.Errorf("Stored 17 devices, got back %v", len(devs))
- }
- if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey {
- t.Errorf("First device identity key does not match")
- }
- if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey {
- t.Errorf("Last device identity key does not match")
- }
+ require.NoError(t, err, "Error getting devices")
+ assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate)
+ assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices")
filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"})
- if err != nil {
- t.Errorf("Error filtering tracked users: %v", err)
- } else if len(filtered) != 1 || filtered[0] != "user1" {
- t.Errorf("Expected to get 'user1' from filter, got %v", filtered)
- }
+ require.NoError(t, err, "Error filtering tracked users")
+ assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter")
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
- if err != nil {
- t.Errorf("Error filtering tracked users: %v", err)
- }
- if len(outdated) > 0 {
- t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
- }
+ require.NoError(t, err, "Error filtering tracked users")
+ assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage")
+
err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"})
- if err != nil {
- t.Errorf("Error marking tracked users outdated: %v", err)
- }
+ require.NoError(t, err, "Error marking tracked users outdated")
+
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
- if err != nil {
- t.Errorf("Error filtering tracked users: %v", err)
- }
- if len(outdated) != 1 || outdated[0] != id.UserID("user1") {
- t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated)
- }
+ require.NoError(t, err, "Error filtering tracked users")
+ assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated")
+
err = store.PutDevices(context.TODO(), "user1", deviceMap)
- if err != nil {
- t.Errorf("Error storing devices: %v", err)
- }
+ require.NoError(t, err, "Error storing devices again")
+
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
- if err != nil {
- t.Errorf("Error filtering tracked users: %v", err)
- }
- if len(outdated) > 0 {
- t.Errorf("Got outdated tracked users %v when expected none", outdated)
- }
+ require.NoError(t, err, "Error filtering tracked users")
+ assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices")
})
}
}
@@ -369,16 +291,11 @@ func TestStoreSecrets(t *testing.T) {
t.Run(storeName, func(t *testing.T) {
storedSecret := "trustno1"
err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret)
- if err != nil {
- t.Errorf("Error storing secret: %v", err)
- }
+ require.NoError(t, err, "Error storing secret")
secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1)
- if err != nil {
- t.Errorf("Error storing secret: %v", err)
- } else if secret != storedSecret {
- t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret)
- }
+ require.NoError(t, err, "Error retrieving secret")
+ assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret")
})
}
}
diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go
index c4f01a68..b12fd9e2 100644
--- a/crypto/utils/utils_test.go
+++ b/crypto/utils/utils_test.go
@@ -9,6 +9,9 @@ package utils
import (
"encoding/base64"
"testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestAES256Ctr(t *testing.T) {
@@ -16,9 +19,7 @@ func TestAES256Ctr(t *testing.T) {
key, iv := GenAttachmentA256CTR()
enc := XorA256CTR([]byte(expected), key, iv)
dec := XorA256CTR(enc, key, iv)
- if string(dec) != expected {
- t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec))
- }
+ assert.EqualValues(t, expected, dec, "Decrypted text should match original")
var key2 [AESCTRKeyLength]byte
var iv2 [AESCTRIVLength]byte
@@ -29,9 +30,7 @@ func TestAES256Ctr(t *testing.T) {
iv2[i] = byte(i) + 32
}
dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2)
- if string(dec2) != expected {
- t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2))
- }
+ assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original")
}
func TestPBKDF(t *testing.T) {
@@ -42,9 +41,7 @@ func TestPBKDF(t *testing.T) {
key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256)
expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E="
keyB64 := base64.StdEncoding.EncodeToString([]byte(key))
- if keyB64 != expected {
- t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64)
- }
+ assert.Equal(t, expected, keyB64)
}
func TestDecodeSSSSKey(t *testing.T) {
@@ -53,13 +50,10 @@ func TestDecodeSSSSKey(t *testing.T) {
expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw="
decodedB64 := base64.StdEncoding.EncodeToString(decoded[:])
- if expected != decodedB64 {
- t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64)
- }
+ assert.Equal(t, expected, decodedB64)
- if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey {
- t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded)
- }
+ encoded := EncodeBase58RecoveryKey(decoded)
+ assert.Equal(t, recoveryKey, encoded)
}
func TestKeyDerivationAndHMAC(t *testing.T) {
@@ -69,15 +63,11 @@ func TestKeyDerivationAndHMAC(t *testing.T) {
aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master")
ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=")
- if err != nil {
- t.Error(err)
- }
+ require.NoError(t, err)
calcMac := HMACSHA256B64(ciphertextBytes, hmacKey)
expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E"
- if calcMac != expectedMac {
- t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac)
- }
+ assert.Equal(t, expectedMac, calcMac)
var ivBytes [AESCTRIVLength]byte
decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==")
@@ -85,7 +75,5 @@ func TestKeyDerivationAndHMAC(t *testing.T) {
decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes))
expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s="
- if expectedDec != decrypted {
- t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted)
- }
+ assert.Equal(t, expectedDec, decrypted)
}
diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go
deleted file mode 100644
index b6bf3d2c..00000000
--- a/crypto/verificationhelper/mockserver_test.go
+++ /dev/null
@@ -1,255 +0,0 @@
-// Copyright (c) 2024 Sumner Evans
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package verificationhelper_test
-
-import (
- "context"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "testing"
-
- "github.com/gorilla/mux"
- "github.com/rs/zerolog/log" // zerolog-allow-global-log
- "github.com/stretchr/testify/require"
- "go.mau.fi/util/random"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/crypto/cryptohelper"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow
-// testing of the interactive verification process.
-type mockServer struct {
- *httptest.Server
-
- AccessTokenToUserID map[string]id.UserID
- DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
- AccountData map[id.UserID]map[event.Type]json.RawMessage
- DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
- MasterKeys map[id.UserID]mautrix.CrossSigningKeys
- SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
- UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
-}
-
-func DecodeVarsMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- vars := mux.Vars(r)
- var err error
- for k, v := range vars {
- vars[k], err = url.PathUnescape(v)
- if err != nil {
- panic(err)
- }
- }
- next.ServeHTTP(w, r)
- })
-}
-
-func createMockServer(t *testing.T) *mockServer {
- t.Helper()
-
- server := mockServer{
- AccessTokenToUserID: map[string]id.UserID{},
- DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
- AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
- DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
- MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- }
-
- router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath()
- router.Use(DecodeVarsMiddleware)
- router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost)
- router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost)
- router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut)
- router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut)
- router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost)
- router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost)
- router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost)
-
- server.Server = httptest.NewServer(router)
- return &server
-}
-
-func (ms *mockServer) getUserID(r *http.Request) id.UserID {
- authHeader := r.Header.Get("Authorization")
- authHeader = strings.TrimPrefix(authHeader, "Bearer ")
- userID, ok := ms.AccessTokenToUserID[authHeader]
- if !ok {
- panic("no user ID found for access token " + authHeader)
- }
- return userID
-}
-
-func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
- w.Write([]byte("{}"))
-}
-
-func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
- var loginReq mautrix.ReqLogin
- json.NewDecoder(r.Body).Decode(&loginReq)
-
- deviceID := loginReq.DeviceID
- if deviceID == "" {
- deviceID = id.DeviceID(random.String(10))
- }
-
- accessToken := random.String(30)
- userID := id.UserID(loginReq.Identifier.User)
- s.AccessTokenToUserID[accessToken] = userID
-
- json.NewEncoder(w).Encode(&mautrix.RespLogin{
- AccessToken: accessToken,
- DeviceID: deviceID,
- UserID: userID,
- })
-}
-
-func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
- vars := mux.Vars(r)
- var req mautrix.ReqSendToDevice
- json.NewDecoder(r.Body).Decode(&req)
- evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType}
-
- for user, devices := range req.Messages {
- for device, content := range devices {
- if _, ok := s.DeviceInbox[user]; !ok {
- s.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
- }
- content.ParseRaw(evtType)
- s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{
- Sender: s.getUserID(r),
- Type: evtType,
- Content: *content,
- })
- }
- }
- s.emptyResp(w, r)
-}
-
-func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
- vars := mux.Vars(r)
- userID := id.UserID(vars["userID"])
- eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType}
-
- jsonData, _ := io.ReadAll(r.Body)
- if _, ok := s.AccountData[userID]; !ok {
- s.AccountData[userID] = map[event.Type]json.RawMessage{}
- }
- s.AccountData[userID][eventType] = json.RawMessage(jsonData)
- s.emptyResp(w, r)
-}
-
-func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
- var req mautrix.ReqQueryKeys
- json.NewDecoder(r.Body).Decode(&req)
- resp := mautrix.RespQueryKeys{
- MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
- }
- for user := range req.DeviceKeys {
- resp.MasterKeys[user] = s.MasterKeys[user]
- resp.UserSigningKeys[user] = s.UserSigningKeys[user]
- resp.SelfSigningKeys[user] = s.SelfSigningKeys[user]
- resp.DeviceKeys[user] = s.DeviceKeys[user]
- }
- json.NewEncoder(w).Encode(&resp)
-}
-
-func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
- var req mautrix.ReqUploadKeys
- json.NewDecoder(r.Body).Decode(&req)
-
- userID := s.getUserID(r)
- if _, ok := s.DeviceKeys[userID]; !ok {
- s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
- }
- s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys
-
- json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{
- OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50},
- })
-}
-
-func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
- var req mautrix.UploadCrossSigningKeysReq
- json.NewDecoder(r.Body).Decode(&req)
-
- userID := s.getUserID(r)
- s.MasterKeys[userID] = req.Master
- s.SelfSigningKeys[userID] = req.SelfSigning
- s.UserSigningKeys[userID] = req.UserSigning
-
- s.emptyResp(w, r)
-}
-
-func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
- t.Helper()
- client, err := mautrix.NewClient(ms.URL, "", "")
- require.NoError(t, err)
- client.StateStore = mautrix.NewMemoryStateStore()
-
- _, err = client.Login(ctx, &mautrix.ReqLogin{
- Type: mautrix.AuthTypePassword,
- Identifier: mautrix.UserIdentifier{
- Type: mautrix.IdentifierTypeUser,
- User: userID.String(),
- },
- DeviceID: deviceID,
- Password: "password",
- StoreCredentials: true,
- })
- require.NoError(t, err)
-
- cryptoStore := crypto.NewMemoryStore(nil)
- cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore)
- require.NoError(t, err)
- client.Crypto = cryptoHelper
-
- err = cryptoHelper.Init(ctx)
- require.NoError(t, err)
-
- machineLog := log.Logger.With().
- Stringer("my_user_id", userID).
- Stringer("my_device_id", deviceID).
- Logger()
- cryptoHelper.Machine().Log = &machineLog
-
- err = cryptoHelper.Machine().ShareKeys(ctx, 50)
- require.NoError(t, err)
-
- return client, cryptoStore
-}
-
-func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) {
- t.Helper()
-
- for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
- client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
- ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
- }
-}
-
-func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) {
- err := cryptoStore.PutDevice(ctx, userID, &id.Device{
- UserID: userID,
- DeviceID: deviceID,
- })
- if err != nil {
- panic(err)
- }
-}
diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go
index 1313a613..e6392c79 100644
--- a/crypto/verificationhelper/sas.go
+++ b/crypto/verificationhelper/sas.go
@@ -695,7 +695,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific
// Verify the MAC for each key
var theirDevice *id.Device
for keyID, mac := range macEvt.MAC {
- log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key")
+ log.Info().Stringer("key_id", keyID).Msg("Received MAC for key")
alg, kID := keyID.Parse()
if alg != id.KeyAlgorithmEd25519 {
diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go
index 9d843ea8..0a781c16 100644
--- a/crypto/verificationhelper/verificationhelper.go
+++ b/crypto/verificationhelper/verificationhelper.go
@@ -848,7 +848,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif
// here, since the start command for scanning and showing QR codes
// should be of type m.reciprocate.v1.
log.Error().Str("method", string(txn.StartEventContent.Method)).Msg("Unsupported verification method in start event")
- vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", txn.StartEventContent.Method))
+ vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, "unknown method %s", txn.StartEventContent.Method)
}
}
diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go
index aace2230..5e3f146b 100644
--- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go
+++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go
@@ -32,7 +32,6 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -51,10 +50,10 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, bobUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@@ -83,7 +82,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device detected that its QR code
// was scanned.
@@ -98,7 +97,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = sendingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
} else { // receiving scans QR
// Emulate scanning the QR code shown by the sending device on
// the receiving device.
@@ -121,7 +120,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device detected that its QR code was
// scanned.
@@ -136,7 +135,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = receivingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
}
// Ensure that both devices have marked the verification as done.
diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go
index 937cc414..ea918cd4 100644
--- a/crypto/verificationhelper/verificationhelper_qr_self_test.go
+++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go
@@ -36,7 +36,6 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -62,7 +61,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
if tc.expectedAcceptError != "" {
@@ -72,7 +71,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
require.NoError(t, err)
}
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@@ -135,7 +134,6 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -152,10 +150,10 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@@ -184,7 +182,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device detected that its QR code
// was scanned.
@@ -199,7 +197,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = sendingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
} else { // receiving scans QR
// Emulate scanning the QR code shown by the sending device on
// the receiving device.
@@ -222,7 +220,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device detected that its QR code was
// scanned.
@@ -237,7 +235,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = receivingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
}
// Ensure that both devices have marked the verification as done.
@@ -251,7 +249,6 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -263,10 +260,10 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes()
sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes()
@@ -310,7 +307,6 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -327,10 +323,10 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes()
sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes()
@@ -348,7 +344,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// Ensure that the receiving device received a cancellation.
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 1)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
cancellation := receivingCallbacks.GetVerificationCancellation(txnID)
require.NotNil(t, cancellation)
assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code)
@@ -362,7 +358,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// Ensure that the sending device received a cancellation.
sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID]
assert.Len(t, sendingInbox, 1)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
cancellation := sendingCallbacks.GetVerificationCancellation(txnID)
require.NotNil(t, cancellation)
assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code)
diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go
index 5747ac34..283eca84 100644
--- a/crypto/verificationhelper/verificationhelper_sas_test.go
+++ b/crypto/verificationhelper/verificationhelper_sas_test.go
@@ -36,7 +36,6 @@ func TestVerification_SAS(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -60,10 +59,10 @@ func TestVerification_SAS(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Test that the start event is correct
var startEvt *event.VerificationStartEventContent
@@ -102,7 +101,7 @@ func TestVerification_SAS(t *testing.T) {
if tc.sendingStartsSAS {
// Process the verification start event on the receiving
// device.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Receiving device sent the accept event to the sending device
sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID]
@@ -110,7 +109,7 @@ func TestVerification_SAS(t *testing.T) {
acceptEvt = sendingInbox[0].Content.AsVerificationAccept()
} else {
// Process the verification start event on the sending device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Sending device sent the accept event to the receiving device
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
@@ -129,7 +128,7 @@ func TestVerification_SAS(t *testing.T) {
var firstKeyEvt *event.VerificationKeyEventContent
if tc.sendingStartsSAS {
// Process the verification accept event on the sending device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Sending device sends first key event to the receiving
// device.
@@ -139,7 +138,7 @@ func TestVerification_SAS(t *testing.T) {
} else {
// Process the verification accept event on the receiving
// device.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Receiving device sends first key event to the sending
// device.
@@ -155,7 +154,7 @@ func TestVerification_SAS(t *testing.T) {
var secondKeyEvt *event.VerificationKeyEventContent
if tc.sendingStartsSAS {
// Process the first key event on the receiving device.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Receiving device sends second key event to the sending
// device.
@@ -170,7 +169,7 @@ func TestVerification_SAS(t *testing.T) {
assert.Len(t, descriptions, 7)
} else {
// Process the first key event on the sending device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Sending device sends second key event to the receiving
// device.
@@ -191,10 +190,10 @@ func TestVerification_SAS(t *testing.T) {
// Ensure that the SAS codes are the same.
if tc.sendingStartsSAS {
// Process the second key event on the sending device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
} else {
// Process the second key event on the receiving device.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
}
assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID))
sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID)
@@ -274,10 +273,10 @@ func TestVerification_SAS(t *testing.T) {
// Test the transaction is done on both sides. We have to dispatch
// twice to process and drain all of the events.
- ts.dispatchToDevice(t, ctx, sendingClient)
- ts.dispatchToDevice(t, ctx, receivingClient)
- ts.dispatchToDevice(t, ctx, sendingClient)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
assert.True(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
})
@@ -288,7 +287,6 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -305,10 +303,10 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
err = sendingHelper.StartSAS(ctx, txnID)
require.NoError(t, err)
@@ -325,7 +323,7 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID)
// Process the start event from the receiving client to the sending client.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 2)
assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID)
@@ -333,13 +331,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// Process the rest of the events until we need to confirm the SAS.
for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 {
- ts.dispatchToDevice(t, ctx, receivingClient)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
}
// Confirm the SAS only the receiving device.
receivingHelper.ConfirmSAS(ctx, txnID)
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Verification is not done until both devices confirm the SAS.
assert.False(t, sendingCallbacks.IsVerificationDone(txnID))
@@ -350,13 +348,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// Dispatching the events to the receiving device should get us to the done
// state on the receiving device.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
assert.False(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
// Dispatching the events to the sending client should get us to the done
// state on the sending device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
assert.True(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
}
diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go
index b4c21c18..ce5ec5b4 100644
--- a/crypto/verificationhelper/verificationhelper_test.go
+++ b/crypto/verificationhelper/verificationhelper_test.go
@@ -19,6 +19,7 @@ import (
"maunium.net/go/mautrix/crypto/verificationhelper"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
+ "maunium.net/go/mautrix/mockserver"
)
var aliceUserID = id.UserID("@alice:example.org")
@@ -31,9 +32,19 @@ func init() {
zerolog.DefaultContextLogger = &log.Logger
}
-func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
+func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) {
+ err := cryptoStore.PutDevice(ctx, userID, &id.Device{
+ UserID: userID,
+ DeviceID: deviceID,
+ })
+ if err != nil {
+ panic(err)
+ }
+}
+
+func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
t.Helper()
- ts = createMockServer(t)
+ ts = mockserver.Create(t)
sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID)
sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine()
@@ -47,9 +58,9 @@ func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServ
return
}
-func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
+func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
t.Helper()
- ts = createMockServer(t)
+ ts = mockserver.Create(t)
sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID)
sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine()
@@ -116,8 +127,7 @@ func TestVerification_Start(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
- ts := createMockServer(t)
- defer ts.Close()
+ ts := mockserver.Create(t)
client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID)
addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID)
@@ -166,7 +176,6 @@ func TestVerification_StartThenCancel(t *testing.T) {
for _, sendingCancels := range []bool{true, false} {
t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) {
ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID)
@@ -186,13 +195,13 @@ func TestVerification_StartThenCancel(t *testing.T) {
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 1)
assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Process the request event on the bystander device.
bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID]
assert.Len(t, bystanderInbox, 1)
assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID)
- ts.dispatchToDevice(t, ctx, bystanderClient)
+ ts.DispatchToDevice(t, ctx, bystanderClient)
// Cancel the verification request.
var cancelEvt *event.VerificationCancelEventContent
@@ -231,7 +240,7 @@ func TestVerification_StartThenCancel(t *testing.T) {
if !sendingCancels {
// Process the cancellation event on the sending device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the cancellation event was sent to the bystander device.
assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1)
@@ -247,8 +256,7 @@ func TestVerification_StartThenCancel(t *testing.T) {
func TestVerification_Accept_NoSupportedMethods(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
- ts := createMockServer(t)
- defer ts.Close()
+ ts := mockserver.Create(t)
sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID)
receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID)
@@ -274,7 +282,7 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, txnID)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiver ignored the request because it
// doesn't support any of the verification methods in the
@@ -314,7 +322,6 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
assert.NoError(t, err)
@@ -333,7 +340,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
require.NoError(t, err)
// Process the verification request on the receiving device.
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device received a verification
// request with the correct transaction ID.
@@ -373,7 +380,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
// Receive the m.key.verification.ready event on the sending
// device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device got a notification about the
// transaction being ready.
@@ -402,7 +409,6 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
nonParticipatingDeviceID1 := id.DeviceID("non-participating1")
@@ -419,12 +425,12 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
// the receiving device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
// Receive the m.key.verification.ready event on the sending device.
- ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.DispatchToDevice(t, ctx, sendingClient)
// The sending and receiving devices should not have any cancellation
// events in their inboxes.
@@ -444,7 +450,6 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
_, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
@@ -452,7 +457,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
err = receivingHelper.AcceptVerification(ctx, txnID)
@@ -472,7 +477,6 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
func TestVerification_CancelOnDoubleStart(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
- defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
_, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
@@ -481,15 +485,15 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) {
// Send and accept the first verification request.
txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID1)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event
+ ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event
// Send a second verification request
txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.DispatchToDevice(t, ctx, receivingClient)
// Ensure that the sending device received a cancellation event for both of
// the ongoing transactions.
@@ -507,7 +511,7 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) {
assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1))
assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2))
- ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events
+ ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events
assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1))
assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2))
}
diff --git a/error.go b/error.go
index 6f4880df..4711b3dc 100644
--- a/error.go
+++ b/error.go
@@ -13,6 +13,7 @@ import (
"net/http"
"go.mau.fi/util/exhttp"
+ "go.mau.fi/util/exmaps"
"golang.org/x/exp/maps"
)
@@ -66,6 +67,8 @@ var (
MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"}
// The client specified a parameter that has the wrong value.
MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest}
+ // The client specified a room key backup version that is not the current room key backup version for the user.
+ MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden}
MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"}
MBadStatus = RespError{ErrCode: "M_BAD_STATUS"}
@@ -79,6 +82,13 @@ var (
var (
ErrClientIsNil = errors.New("client is nil")
ErrClientHasNoHomeserver = errors.New("client has no homeserver set")
+
+ ErrResponseTooLong = errors.New("response content length too long")
+ ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body")
+
+ // Special error that indicates we should retry canceled contexts. Note that on it's own this
+ // is useless, the context itself must also be replaced.
+ ErrContextCancelRetry = errors.New("retry canceled context")
)
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.
@@ -130,7 +140,10 @@ type RespError struct {
Err string
ExtraData map[string]any
- StatusCode int
+ StatusCode int
+ ExtraHeader map[string]string
+
+ CanRetry bool
}
func (e *RespError) UnmarshalJSON(data []byte) error {
@@ -140,16 +153,17 @@ func (e *RespError) UnmarshalJSON(data []byte) error {
}
e.ErrCode, _ = e.ExtraData["errcode"].(string)
e.Err, _ = e.ExtraData["error"].(string)
+ e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool)
return nil
}
func (e *RespError) MarshalJSON() ([]byte, error) {
- data := maps.Clone(e.ExtraData)
- if data == nil {
- data = make(map[string]any)
- }
+ data := exmaps.NonNilClone(e.ExtraData)
data["errcode"] = e.ErrCode
data["error"] = e.Err
+ if e.CanRetry {
+ data["com.beeper.can_retry"] = e.CanRetry
+ }
return json.Marshal(data)
}
@@ -161,6 +175,9 @@ func (e RespError) Write(w http.ResponseWriter) {
if statusCode == 0 {
statusCode = http.StatusInternalServerError
}
+ for key, value := range e.ExtraHeader {
+ w.Header().Set(key, value)
+ }
exhttp.WriteJSONResponse(w, statusCode, &e)
}
@@ -177,6 +194,29 @@ func (e RespError) WithStatus(status int) RespError {
return e
}
+func (e RespError) WithCanRetry(canRetry bool) RespError {
+ e.CanRetry = canRetry
+ return e
+}
+
+func (e RespError) WithExtraData(extraData map[string]any) RespError {
+ e.ExtraData = exmaps.NonNilClone(e.ExtraData)
+ maps.Copy(e.ExtraData, extraData)
+ return e
+}
+
+func (e RespError) WithExtraHeader(key, value string) RespError {
+ e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader)
+ e.ExtraHeader[key] = value
+ return e
+}
+
+func (e RespError) WithExtraHeaders(headers map[string]string) RespError {
+ e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader)
+ maps.Copy(e.ExtraHeader, headers)
+ return e
+}
+
// Error returns the errcode and error message.
func (e RespError) Error() string {
return e.ErrCode + ": " + e.Err
diff --git a/event/accountdata.go b/event/accountdata.go
index 30ca35a2..223919a1 100644
--- a/event/accountdata.go
+++ b/event/accountdata.go
@@ -105,3 +105,15 @@ func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time {
}
return time.Time{}
}
+
+func (bmec *BeeperMuteEventContent) GetMuteDuration() time.Duration {
+ ts := bmec.GetMutedUntilTime()
+ now := time.Now()
+ if ts.Before(now) {
+ return 0
+ } else if ts == MutedForever {
+ return -1
+ } else {
+ return ts.Sub(now)
+ }
+}
diff --git a/event/beeper.go b/event/beeper.go
index 921e3466..a1a60b35 100644
--- a/event/beeper.go
+++ b/event/beeper.go
@@ -53,6 +53,8 @@ type BeeperMessageStatusEventContent struct {
LastRetry id.EventID `json:"last_retry,omitempty"`
+ TargetTxnID string `json:"relates_to_txn_id,omitempty"`
+
MutateEventKey string `json:"mutate_event_key,omitempty"`
// Indicates the set of users to whom the event was delivered. If nil, then
@@ -86,6 +88,22 @@ type BeeperRoomKeyAckEventContent struct {
FirstMessageIndex int `json:"first_message_index"`
}
+type BeeperChatDeleteEventContent struct {
+ DeleteForEveryone bool `json:"delete_for_everyone,omitempty"`
+ FromMessageRequest bool `json:"from_message_request,omitempty"`
+}
+
+type BeeperAcceptMessageRequestEventContent struct {
+ // Whether this was triggered by a message rather than an explicit event
+ IsImplicit bool `json:"-"`
+}
+
+type BeeperSendStateEventContent struct {
+ Type string `json:"type"`
+ StateKey string `json:"state_key"`
+ Content Content `json:"content"`
+}
+
type IntOrString int
func (ios *IntOrString) UnmarshalJSON(data []byte) error {
@@ -128,6 +146,7 @@ type BeeperLinkPreview struct {
MatchedURL string `json:"matched_url,omitempty"`
ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"`
+ ImageBlurhash string `json:"matrix:image:blurhash,omitempty"`
}
type BeeperProfileExtra struct {
@@ -147,6 +166,24 @@ type BeeperPerMessageProfile struct {
HasFallback bool `json:"has_fallback,omitempty"`
}
+type BeeperActionMessageType string
+
+const (
+ BeeperActionMessageCall BeeperActionMessageType = "call"
+)
+
+type BeeperActionMessageCallType string
+
+const (
+ BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice"
+ BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video"
+)
+
+type BeeperActionMessage struct {
+ Type BeeperActionMessageType `json:"type"`
+ CallType BeeperActionMessageCallType `json:"call_type,omitempty"`
+}
+
func (content *MessageEventContent) AddPerMessageProfileFallback() {
if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" {
return
@@ -177,6 +214,15 @@ func (content *MessageEventContent) RemovePerMessageProfileFallback() {
}
}
+type BeeperAIStreamEventContent struct {
+ TurnID string `json:"turn_id"`
+ Seq int `json:"seq"`
+ Part map[string]any `json:"part"`
+ TargetEvent id.EventID `json:"target_event,omitempty"`
+ AgentID string `json:"agent_id,omitempty"`
+ RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
+}
+
type BeeperEncodedOrder struct {
order int64
suborder int16
diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts
index 4cf29de7..26aeb347 100644
--- a/event/capabilities.d.ts
+++ b/event/capabilities.d.ts
@@ -16,6 +16,23 @@ export interface RoomFeatures {
* If a message type isn't listed here, it should be treated as support level -2 (will be rejected).
*/
file?: Record
+ /**
+ * Supported state event types and their parameters. Currently, there are no parameters,
+ * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.).
+ *
+ * Events that are not listed or have a support level of zero or below should be treated as unsupported.
+ *
+ * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here.
+ * `m.room.member` will not be listed here, as it's controlled by the member_actions field.
+ * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now.
+ */
+ state?: Record
+ /**
+ * Supported member actions and their support levels.
+ *
+ * Actions that are not listed or have a support level of zero or below should be treated as unsupported.
+ */
+ member_actions?: Record
/** Maximum length of normal text messages. */
max_text_length?: integer
@@ -41,6 +58,8 @@ export interface RoomFeatures {
delete_max_age?: seconds
/** Whether deleting messages just for yourself is supported. No message age limit. */
delete_for_me?: boolean
+ /** Allowed configuration options for disappearing timers. */
+ disappearing_timer?: DisappearingTimerCapability
/** Whether reactions are supported. */
reaction?: CapabilitySupportLevel
@@ -53,10 +72,21 @@ export interface RoomFeatures {
allowed_reactions?: string[]
/** Whether custom emoji reactions are allowed. */
custom_emoji_reactions?: boolean
+
+ /** Whether deleting the chat for yourself is supported. */
+ delete_chat?: boolean
+ /** Whether deleting the chat for all participants is supported. */
+ delete_chat_for_everyone?: boolean
+ /** What can be done with message requests? */
+ message_request?: {
+ accept_with_message?: CapabilitySupportLevel
+ accept_with_button?: CapabilitySupportLevel
+ }
}
declare type integer = number
declare type seconds = integer
+declare type milliseconds = integer
declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application"
declare type MIMETypeOrPattern =
"*/*"
@@ -64,6 +94,21 @@ declare type MIMETypeOrPattern =
| `${MIMEClass}/${string}`
| `${MIMEClass}/${string}; ${string}`
+export enum MemberAction {
+ Ban = "ban",
+ Kick = "kick",
+ Leave = "leave",
+ RevokeInvite = "revoke_invite",
+ Invite = "invite",
+}
+
+declare type EventType = string
+
+// This is an object for future extensibility (e.g. max name/topic length)
+export interface StateFeatures {
+ level: CapabilitySupportLevel
+}
+
export enum CapabilityMsgType {
// Real message types used in the `msgtype` field
Image = "m.image",
@@ -106,6 +151,25 @@ export interface FileFeatures {
view_once?: boolean
}
+export enum DisappearingType {
+ None = "",
+ AfterRead = "after_read",
+ AfterSend = "after_send",
+}
+
+export interface DisappearingTimerCapability {
+ types: DisappearingType[]
+ /** Allowed timer values. If omitted, any timer is allowed. */
+ timers?: milliseconds[]
+ /**
+ * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear
+ *
+ * Generally, bridged rooms will want the object to be always present, while native Matrix rooms don't,
+ * so the hardcoded features for Matrix rooms should set this to true, while bridges will not.
+ */
+ omit_empty_timer?: true
+}
+
/**
* The support level for a feature. These are integers rather than booleans
* to accurately represent what the bridge is doing and hopefully make the
diff --git a/event/capabilities.go b/event/capabilities.go
index 9c9eb09a..a86c726b 100644
--- a/event/capabilities.go
+++ b/event/capabilities.go
@@ -18,6 +18,7 @@ import (
"go.mau.fi/util/exerrors"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
"golang.org/x/exp/constraints"
"golang.org/x/exp/maps"
)
@@ -27,8 +28,10 @@ type RoomFeatures struct {
// N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
- Formatting FormattingFeatureMap `json:"formatting,omitempty"`
- File FileFeatureMap `json:"file,omitempty"`
+ Formatting FormattingFeatureMap `json:"formatting,omitempty"`
+ File FileFeatureMap `json:"file,omitempty"`
+ State StateFeatureMap `json:"state,omitempty"`
+ MemberActions MemberFeatureMap `json:"member_actions,omitempty"`
MaxTextLength int `json:"max_text_length,omitempty"`
@@ -44,16 +47,23 @@ type RoomFeatures struct {
DeleteForMe bool `json:"delete_for_me,omitempty"`
DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"`
+ DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"`
+
Reaction CapabilitySupportLevel `json:"reaction,omitempty"`
ReactionCount int `json:"reaction_count,omitempty"`
AllowedReactions []string `json:"allowed_reactions,omitempty"`
CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"`
- ReadReceipts bool `json:"read_receipts,omitempty"`
- TypingNotifications bool `json:"typing_notifications,omitempty"`
- Archive bool `json:"archive,omitempty"`
- MarkAsUnread bool `json:"mark_as_unread,omitempty"`
- DeleteChat bool `json:"delete_chat,omitempty"`
+ ReadReceipts bool `json:"read_receipts,omitempty"`
+ TypingNotifications bool `json:"typing_notifications,omitempty"`
+ Archive bool `json:"archive,omitempty"`
+ MarkAsUnread bool `json:"mark_as_unread,omitempty"`
+ DeleteChat bool `json:"delete_chat,omitempty"`
+ DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"`
+
+ MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"`
+
+ PerMessageProfileRelay bool `json:"-"`
}
func (rf *RoomFeatures) GetID() string {
@@ -63,10 +73,120 @@ func (rf *RoomFeatures) GetID() string {
return base64.RawURLEncoding.EncodeToString(rf.Hash())
}
+func (rf *RoomFeatures) Clone() *RoomFeatures {
+ if rf == nil {
+ return nil
+ }
+ clone := *rf
+ clone.File = clone.File.Clone()
+ clone.Formatting = maps.Clone(clone.Formatting)
+ clone.State = clone.State.Clone()
+ clone.MemberActions = clone.MemberActions.Clone()
+ clone.EditMaxAge = ptr.Clone(clone.EditMaxAge)
+ clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge)
+ clone.DisappearingTimer = clone.DisappearingTimer.Clone()
+ clone.AllowedReactions = slices.Clone(clone.AllowedReactions)
+ clone.MessageRequest = clone.MessageRequest.Clone()
+ return &clone
+}
+
+type MemberFeatureMap map[MemberAction]CapabilitySupportLevel
+
+func (mfm MemberFeatureMap) Clone() MemberFeatureMap {
+ return maps.Clone(mfm)
+}
+
+type MemberAction string
+
+const (
+ MemberActionBan MemberAction = "ban"
+ MemberActionKick MemberAction = "kick"
+ MemberActionLeave MemberAction = "leave"
+ MemberActionRevokeInvite MemberAction = "revoke_invite"
+ MemberActionInvite MemberAction = "invite"
+)
+
+type StateFeatureMap map[string]*StateFeatures
+
+func (sfm StateFeatureMap) Clone() StateFeatureMap {
+ dup := maps.Clone(sfm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type StateFeatures struct {
+ Level CapabilitySupportLevel `json:"level"`
+}
+
+func (sf *StateFeatures) Clone() *StateFeatures {
+ if sf == nil {
+ return nil
+ }
+ clone := *sf
+ return &clone
+}
+
+func (sf *StateFeatures) Hash() []byte {
+ return sf.Level.Hash()
+}
+
type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel
type FileFeatureMap map[CapabilityMsgType]*FileFeatures
+func (ffm FileFeatureMap) Clone() FileFeatureMap {
+ dup := maps.Clone(ffm)
+ for key, value := range dup {
+ dup[key] = value.Clone()
+ }
+ return dup
+}
+
+type DisappearingTimerCapability struct {
+ Types []DisappearingType `json:"types"`
+ Timers []jsontime.Milliseconds `json:"timers,omitempty"`
+
+ OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"`
+}
+
+func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability {
+ if dtc == nil {
+ return nil
+ }
+ clone := *dtc
+ clone.Types = slices.Clone(clone.Types)
+ clone.Timers = slices.Clone(clone.Timers)
+ return &clone
+}
+
+func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool {
+ if dtc == nil || content == nil || content.Type == DisappearingTypeNone {
+ return true
+ }
+ return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer))
+}
+
+type MessageRequestFeatures struct {
+ AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"`
+ AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"`
+}
+
+func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures {
+ return ptr.Clone(mrf)
+}
+
+func (mrf *MessageRequestFeatures) Hash() []byte {
+ if mrf == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage)
+ hashValue(hasher, "accept_with_button", mrf.AcceptWithButton)
+ return hasher.Sum(nil)
+}
+
type CapabilityMsgType = MessageType
// Message types which are used for event capability signaling, but aren't real values for the msgtype field.
@@ -216,6 +336,8 @@ func (rf *RoomFeatures) Hash() []byte {
hashMap(hasher, "formatting", rf.Formatting)
hashMap(hasher, "file", rf.File)
+ hashMap(hasher, "state", rf.State)
+ hashMap(hasher, "member_actions", rf.MemberActions)
hashInt(hasher, "max_text_length", rf.MaxTextLength)
@@ -231,6 +353,7 @@ func (rf *RoomFeatures) Hash() []byte {
hashValue(hasher, "delete", rf.Delete)
hashBool(hasher, "delete_for_me", rf.DeleteForMe)
hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get())
+ hashValue(hasher, "disappearing_timer", rf.DisappearingTimer)
hashValue(hasher, "reaction", rf.Reaction)
hashInt(hasher, "reaction_count", rf.ReactionCount)
@@ -245,10 +368,28 @@ func (rf *RoomFeatures) Hash() []byte {
hashBool(hasher, "archive", rf.Archive)
hashBool(hasher, "mark_as_unread", rf.MarkAsUnread)
hashBool(hasher, "delete_chat", rf.DeleteChat)
+ hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone)
+ hashValue(hasher, "message_request", rf.MessageRequest)
return hasher.Sum(nil)
}
+func (dtc *DisappearingTimerCapability) Hash() []byte {
+ if dtc == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hasher.Write([]byte("types"))
+ for _, t := range dtc.Types {
+ hasher.Write([]byte(t))
+ }
+ hasher.Write([]byte("timers"))
+ for _, timer := range dtc.Timers {
+ hashInt(hasher, "", timer.Milliseconds())
+ }
+ return hasher.Sum(nil)
+}
+
func (ff *FileFeatures) Hash() []byte {
hasher := sha256.New()
hashMap(hasher, "mime_types", ff.MimeTypes)
@@ -261,3 +402,13 @@ func (ff *FileFeatures) Hash() []byte {
hashBool(hasher, "view_once", ff.ViewOnce)
return hasher.Sum(nil)
}
+
+func (ff *FileFeatures) Clone() *FileFeatures {
+ if ff == nil {
+ return nil
+ }
+ clone := *ff
+ clone.MimeTypes = maps.Clone(clone.MimeTypes)
+ clone.MaxDuration = ptr.Clone(clone.MaxDuration)
+ return &clone
+}
diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go
new file mode 100644
index 00000000..ce07c4c0
--- /dev/null
+++ b/event/cmdschema/content.go
@@ -0,0 +1,78 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "reflect"
+ "slices"
+
+ "go.mau.fi/util/exsync"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type EventContent struct {
+ Command string `json:"command"`
+ Aliases []string `json:"aliases,omitempty"`
+ Parameters []*Parameter `json:"parameters,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ TailParam string `json:"fi.mau.tail_parameter,omitempty"`
+}
+
+func (ec *EventContent) Validate() error {
+ if ec == nil {
+ return fmt.Errorf("event content is nil")
+ } else if ec.Command == "" {
+ return fmt.Errorf("command is empty")
+ }
+ var tailFound bool
+ dupMap := exsync.NewSet[string]()
+ for i, p := range ec.Parameters {
+ if err := p.Validate(); err != nil {
+ return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err)
+ } else if !dupMap.Add(p.Key) {
+ return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1)
+ } else if p.Key == ec.TailParam {
+ tailFound = true
+ } else if tailFound && !p.Optional {
+ return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam)
+ }
+ }
+ if ec.TailParam != "" && !tailFound {
+ return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam)
+ }
+ return nil
+}
+
+func (ec *EventContent) IsValid() bool {
+ return ec.Validate() == nil
+}
+
+func (ec *EventContent) StateKey(owner id.UserID) string {
+ hash := sha256.Sum256([]byte(ec.Command + owner.String()))
+ return base64.StdEncoding.EncodeToString(hash[:])
+}
+
+func (ec *EventContent) Equals(other *EventContent) bool {
+ if ec == nil || other == nil {
+ return ec == other
+ }
+ return ec.Command == other.Command &&
+ slices.Equal(ec.Aliases, other.Aliases) &&
+ slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) &&
+ ec.Description.Equals(other.Description) &&
+ ec.TailParam == other.TailParam
+}
+
+func init() {
+ event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{})
+}
diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go
new file mode 100644
index 00000000..4193b297
--- /dev/null
+++ b/event/cmdschema/parameter.go
@@ -0,0 +1,286 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+
+ "go.mau.fi/util/exslices"
+
+ "maunium.net/go/mautrix/event"
+)
+
+type Parameter struct {
+ Key string `json:"key"`
+ Schema *ParameterSchema `json:"schema"`
+ Optional bool `json:"optional,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ DefaultValue any `json:"fi.mau.default_value,omitempty"`
+}
+
+func (p *Parameter) Equals(other *Parameter) bool {
+ if p == nil || other == nil {
+ return p == other
+ }
+ return p.Key == other.Key &&
+ p.Schema.Equals(other.Schema) &&
+ p.Optional == other.Optional &&
+ p.Description.Equals(other.Description) &&
+ p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values
+}
+
+func (p *Parameter) Validate() error {
+ if p == nil {
+ return fmt.Errorf("parameter is nil")
+ } else if p.Key == "" {
+ return fmt.Errorf("key is empty")
+ }
+ return p.Schema.Validate()
+}
+
+func (p *Parameter) IsValid() bool {
+ return p.Validate() == nil
+}
+
+func (p *Parameter) GetDefaultValue() any {
+ if p != nil && p.DefaultValue != nil {
+ return p.DefaultValue
+ } else if p == nil || p.Optional {
+ return nil
+ }
+ return p.Schema.GetDefaultValue()
+}
+
+type PrimitiveType string
+
+const (
+ PrimitiveTypeString PrimitiveType = "string"
+ PrimitiveTypeInteger PrimitiveType = "integer"
+ PrimitiveTypeBoolean PrimitiveType = "boolean"
+ PrimitiveTypeServerName PrimitiveType = "server_name"
+ PrimitiveTypeUserID PrimitiveType = "user_id"
+ PrimitiveTypeRoomID PrimitiveType = "room_id"
+ PrimitiveTypeRoomAlias PrimitiveType = "room_alias"
+ PrimitiveTypeEventID PrimitiveType = "event_id"
+)
+
+func (pt PrimitiveType) Schema() *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypePrimitive,
+ Type: pt,
+ }
+}
+
+func (pt PrimitiveType) IsValid() bool {
+ switch pt {
+ case PrimitiveTypeString,
+ PrimitiveTypeInteger,
+ PrimitiveTypeBoolean,
+ PrimitiveTypeServerName,
+ PrimitiveTypeUserID,
+ PrimitiveTypeRoomID,
+ PrimitiveTypeRoomAlias,
+ PrimitiveTypeEventID:
+ return true
+ default:
+ return false
+ }
+}
+
+type SchemaType string
+
+const (
+ SchemaTypePrimitive SchemaType = "primitive"
+ SchemaTypeArray SchemaType = "array"
+ SchemaTypeUnion SchemaType = "union"
+ SchemaTypeLiteral SchemaType = "literal"
+)
+
+type ParameterSchema struct {
+ SchemaType SchemaType `json:"schema_type"`
+ Type PrimitiveType `json:"type,omitempty"` // Only for primitive
+ Items *ParameterSchema `json:"items,omitempty"` // Only for array
+ Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union
+ Value any `json:"value,omitempty"` // Only for literal
+}
+
+func Literal(value any) *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypeLiteral,
+ Value: value,
+ }
+}
+
+func Enum(values ...any) *ParameterSchema {
+ return Union(exslices.CastFunc(values, Literal)...)
+}
+
+func flattenUnion(variants []*ParameterSchema) []*ParameterSchema {
+ var flattened []*ParameterSchema
+ for _, variant := range variants {
+ switch variant.SchemaType {
+ case SchemaTypeArray:
+ panic(fmt.Errorf("illegal array schema in union"))
+ case SchemaTypeUnion:
+ flattened = append(flattened, flattenUnion(variant.Variants)...)
+ default:
+ flattened = append(flattened, variant)
+ }
+ }
+ return flattened
+}
+
+func Union(variants ...*ParameterSchema) *ParameterSchema {
+ needsFlattening := false
+ for _, variant := range variants {
+ if variant.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in union"))
+ } else if variant.SchemaType == SchemaTypeUnion {
+ needsFlattening = true
+ }
+ }
+ if needsFlattening {
+ variants = flattenUnion(variants)
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeUnion,
+ Variants: variants,
+ }
+}
+
+func Array(items *ParameterSchema) *ParameterSchema {
+ if items.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in array"))
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeArray,
+ Items: items,
+ }
+}
+
+func (ps *ParameterSchema) GetDefaultValue() any {
+ if ps == nil {
+ return nil
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ switch ps.Type {
+ case PrimitiveTypeInteger:
+ return 0
+ case PrimitiveTypeBoolean:
+ return false
+ default:
+ return ""
+ }
+ case SchemaTypeArray:
+ return []any{}
+ case SchemaTypeUnion:
+ if len(ps.Variants) > 0 {
+ return ps.Variants[0].GetDefaultValue()
+ }
+ return nil
+ case SchemaTypeLiteral:
+ return ps.Value
+ default:
+ return nil
+ }
+}
+
+func (ps *ParameterSchema) IsValid() bool {
+ return ps.validate("") == nil
+}
+
+func (ps *ParameterSchema) Validate() error {
+ return ps.validate("")
+}
+
+func (ps *ParameterSchema) validate(parent SchemaType) error {
+ if ps == nil {
+ return fmt.Errorf("schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ if !ps.Type.IsValid() {
+ return fmt.Errorf("invalid primitive type %s", ps.Type)
+ } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("primitive schema has extra fields")
+ }
+ return nil
+ case SchemaTypeArray:
+ if parent != "" {
+ return fmt.Errorf("arrays can't be nested in other types")
+ } else if err := ps.Items.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("item schema is invalid: %w", err)
+ } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("array schema has extra fields")
+ }
+ return nil
+ case SchemaTypeUnion:
+ if len(ps.Variants) == 0 {
+ return fmt.Errorf("no variants specified for union")
+ } else if parent != "" && parent != SchemaTypeArray {
+ return fmt.Errorf("unions can't be nested in anything other than arrays")
+ }
+ for i, v := range ps.Variants {
+ if err := v.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("variant #%d is invalid: %w", i+1, err)
+ }
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Value != nil {
+ return fmt.Errorf("union schema has extra fields")
+ }
+ return nil
+ case SchemaTypeLiteral:
+ switch typedVal := ps.Value.(type) {
+ case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue:
+ // ok
+ case map[string]any:
+ if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" {
+ return fmt.Errorf("literal value has invalid map data")
+ }
+ default:
+ return fmt.Errorf("literal value has unsupported type %T", ps.Value)
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Variants != nil {
+ return fmt.Errorf("literal schema has extra fields")
+ }
+ return nil
+ default:
+ return fmt.Errorf("invalid schema type %s", ps.SchemaType)
+ }
+}
+
+func (ps *ParameterSchema) Equals(other *ParameterSchema) bool {
+ if ps == nil || other == nil {
+ return ps == other
+ }
+ return ps.SchemaType == other.SchemaType &&
+ ps.Type == other.Type &&
+ ps.Items.Equals(other.Items) &&
+ slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) &&
+ ps.Value == other.Value // TODO this won't work for room/event ID values
+}
+
+func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool {
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type == prim
+ case SchemaTypeUnion:
+ for _, variant := range ps.Variants {
+ if variant.AllowsPrimitive(prim) {
+ return true
+ }
+ }
+ return false
+ case SchemaTypeArray:
+ return ps.Items.AllowsPrimitive(prim)
+ default:
+ return false
+ }
+}
diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go
new file mode 100644
index 00000000..92e69b60
--- /dev/null
+++ b/event/cmdschema/parse.go
@@ -0,0 +1,478 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+const botArrayOpener = "<"
+const botArrayCloser = ">"
+
+func parseQuoted(val string) (parsed, remaining string, quoted bool) {
+ if len(val) == 0 {
+ return
+ }
+ if !strings.HasPrefix(val, `"`) {
+ spaceIdx := strings.IndexByte(val, ' ')
+ if spaceIdx == -1 {
+ parsed = val
+ } else {
+ parsed = val[:spaceIdx]
+ remaining = strings.TrimLeft(val[spaceIdx+1:], " ")
+ }
+ return
+ }
+ val = val[1:]
+ var buf strings.Builder
+ for {
+ quoteIdx := strings.IndexByte(val, '"')
+ var valUntilQuote string
+ if quoteIdx == -1 {
+ valUntilQuote = val
+ } else {
+ valUntilQuote = val[:quoteIdx]
+ }
+ escapeIdx := strings.IndexByte(valUntilQuote, '\\')
+ if escapeIdx >= 0 {
+ buf.WriteString(val[:escapeIdx])
+ if len(val) > escapeIdx+1 {
+ buf.WriteByte(val[escapeIdx+1])
+ }
+ val = val[min(escapeIdx+2, len(val)):]
+ } else if quoteIdx >= 0 {
+ buf.WriteString(val[:quoteIdx])
+ val = val[quoteIdx+1:]
+ break
+ } else if buf.Len() == 0 {
+ // Unterminated quote, no escape characters, val is the whole input
+ return val, "", true
+ } else {
+ // Unterminated quote, but there were escape characters previously
+ buf.WriteString(val)
+ val = ""
+ break
+ }
+ }
+ return buf.String(), strings.TrimLeft(val, " "), true
+}
+
+// ParseInput tries to parse the given text into a bot command event matching this command definition.
+//
+// If the prefix doesn't match, this will return a nil content and nil error.
+// If the prefix does match, some content is always returned, but there may still be an error if parsing failed.
+func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) {
+ prefix := ec.parsePrefix(input, sigils, owner.String())
+ if prefix == "" {
+ return nil, nil
+ }
+ content = &event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: input,
+ Mentions: &event.Mentions{UserIDs: []id.UserID{owner}},
+ MSC4391BotCommand: &event.MSC4391BotCommandInput{
+ Command: ec.Command,
+ },
+ }
+ content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):])
+ return content, err
+}
+
+func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) {
+ args := make(map[string]any)
+ var retErr error
+ setError := func(err error) {
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+ processParameter := func(param *Parameter, isLast, isTail, isNamed bool) {
+ origInput := input
+ var nextVal string
+ var wasQuoted bool
+ if param.Schema.SchemaType == SchemaTypeArray {
+ hasOpener := strings.HasPrefix(input, botArrayOpener)
+ arrayClosed := false
+ if hasOpener {
+ input = input[len(botArrayOpener):]
+ if strings.HasPrefix(input, botArrayCloser) {
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ }
+ }
+ var collector []any
+ for len(input) > 0 && !arrayClosed {
+ //origInput = input
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) {
+ // The value wasn't quoted and has the array delimiter at the end, close the array
+ nextVal = strings.TrimRight(nextVal, botArrayCloser)
+ arrayClosed = true
+ } else if hasOpener && strings.HasPrefix(input, botArrayCloser) {
+ // The value was quoted or there was a space, and the next character is the
+ // array delimiter, close the array
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ } else if !hasOpener && !isLast {
+ // For array arguments in the middle without the <> delimiters, stop after the first item
+ arrayClosed = true
+ }
+ parsedVal, err := param.Schema.Items.ParseString(nextVal)
+ if err == nil {
+ collector = append(collector, parsedVal)
+ } else if hasOpener || isLast {
+ setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err))
+ } else {
+ //input = origInput
+ }
+ }
+ args[param.Key] = collector
+ } else {
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if (isLast || isTail) && !wasQuoted && len(input) > 0 {
+ // If the last argument is not quoted, just treat the rest of the string
+ // as the argument without escapes (arguments with escapes should be quoted).
+ nextVal += " " + input
+ input = ""
+ }
+ // Special case for named boolean parameters: if no value is given, treat it as true
+ if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) {
+ args[param.Key] = true
+ return
+ }
+ if nextVal == "" && !wasQuoted && !isNamed && !param.Optional {
+ setError(fmt.Errorf("missing value for required parameter %s", param.Key))
+ }
+ parsedVal, err := param.Schema.ParseString(nextVal)
+ if err != nil {
+ args[param.Key] = param.GetDefaultValue()
+ // For optional parameters that fail to parse, restore the input and try passing it as the next parameter
+ if param.Optional && !isLast && !isNamed {
+ input = strings.TrimLeft(origInput, " ")
+ } else if !param.Optional || isNamed {
+ setError(fmt.Errorf("failed to parse %s: %w", param.Key, err))
+ }
+ } else {
+ args[param.Key] = parsedVal
+ }
+ }
+ }
+ skipParams := make([]bool, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ for strings.HasPrefix(input, "--") {
+ nameEndIdx := strings.IndexAny(input, " =")
+ if nameEndIdx == -1 {
+ nameEndIdx = len(input)
+ }
+ overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx])
+ if overrideParam != nil {
+ // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input
+ input = strings.TrimPrefix(input[nameEndIdx:], "=")
+ skipParams[paramIdx] = true
+ processParameter(overrideParam, false, false, true)
+ } else {
+ break
+ }
+ }
+ isTail := param.Key == ec.TailParam
+ if skipParams[i] || (param.Optional && !isTail) {
+ continue
+ }
+ processParameter(param, i == len(ec.Parameters)-1, isTail, false)
+ }
+ jsonArgs, marshalErr := json.Marshal(args)
+ if marshalErr != nil {
+ return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr)
+ }
+ return jsonArgs, retErr
+}
+
+func (ec *EventContent) parameterByName(name string) (*Parameter, int) {
+ for i, param := range ec.Parameters {
+ if strings.EqualFold(param.Key, name) {
+ return param, i
+ }
+ }
+ return nil, -1
+}
+
+func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) {
+ input := origInput
+ var chosenSigil string
+ for _, sigil := range sigils {
+ if strings.HasPrefix(input, sigil) {
+ chosenSigil = sigil
+ break
+ }
+ }
+ if chosenSigil == "" {
+ return ""
+ }
+ input = input[len(chosenSigil):]
+ var chosenAlias string
+ if !strings.HasPrefix(input, ec.Command) {
+ for _, alias := range ec.Aliases {
+ if strings.HasPrefix(input, alias) {
+ chosenAlias = alias
+ break
+ }
+ }
+ if chosenAlias == "" {
+ return ""
+ }
+ } else {
+ chosenAlias = ec.Command
+ }
+ input = strings.TrimPrefix(input[len(chosenAlias):], owner)
+ if input == "" || input[0] == ' ' {
+ input = strings.TrimLeft(input, " ")
+ return origInput[:len(origInput)-len(input)]
+ }
+ return ""
+}
+
+func (pt PrimitiveType) ValidateValue(value any) bool {
+ _, err := pt.NormalizeValue(value)
+ return err == nil
+}
+
+func normalizeNumber(value any) (int, error) {
+ switch typedValue := value.(type) {
+ case int:
+ return typedValue, nil
+ case int64:
+ return int(typedValue), nil
+ case float64:
+ return int(typedValue), nil
+ case json.Number:
+ if i, err := typedValue.Int64(); err != nil {
+ return 0, fmt.Errorf("failed to parse json.Number: %w", err)
+ } else {
+ return int(i), nil
+ }
+ default:
+ return 0, fmt.Errorf("unsupported type %T for integer", value)
+ }
+}
+
+func (pt PrimitiveType) NormalizeValue(value any) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return normalizeNumber(value)
+ case PrimitiveTypeBoolean:
+ bv, ok := value.(bool)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for boolean", value)
+ }
+ return bv, nil
+ case PrimitiveTypeString, PrimitiveTypeServerName:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for string", value)
+ }
+ return str, pt.validateStringValue(str)
+ case PrimitiveTypeUserID, PrimitiveTypeRoomAlias:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value)
+ } else if plainErr := pt.validateStringValue(str); plainErr == nil {
+ return str, nil
+ } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID {
+ return parsed.UserID(), nil
+ } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias {
+ return parsed.RoomAlias(), nil
+ } else {
+ return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1)
+ }
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ riv, err := NormalizeRoomIDValue(value)
+ if err != nil {
+ return nil, err
+ }
+ return riv, riv.Validate()
+ default:
+ return nil, fmt.Errorf("cannot normalize value for argument type %s", pt)
+ }
+}
+
+func (pt PrimitiveType) validateStringValue(value string) error {
+ switch pt {
+ case PrimitiveTypeString:
+ return nil
+ case PrimitiveTypeServerName:
+ if !id.ValidateServerName(value) {
+ return fmt.Errorf("invalid server name: %q", value)
+ }
+ return nil
+ case PrimitiveTypeUserID:
+ _, _, err := id.UserID(value).ParseAndValidateRelaxed()
+ return err
+ case PrimitiveTypeRoomAlias:
+ sigil, localpart, serverName := id.ParseCommonIdentifier(value)
+ if sigil != '#' || localpart == "" || serverName == "" {
+ return fmt.Errorf("invalid room alias: %q", value)
+ } else if !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name in room alias: %q", serverName)
+ }
+ return nil
+ default:
+ panic(fmt.Errorf("validateStringValue called with invalid type %s", pt))
+ }
+}
+
+func parseBoolean(val string) (bool, error) {
+ if len(val) == 0 {
+ return false, fmt.Errorf("cannot parse empty string as boolean")
+ }
+ switch strings.ToLower(val) {
+ case "t", "true", "y", "yes", "1":
+ return true, nil
+ case "f", "false", "n", "no", "0":
+ return false, nil
+ default:
+ return false, fmt.Errorf("invalid boolean string: %q", val)
+ }
+}
+
+var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`)
+
+func parseRoomOrEventID(value string) (*RoomIDValue, error) {
+ if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") {
+ matches := markdownLinkRegex.FindStringSubmatch(value)
+ if len(matches) == 2 {
+ value = matches[1]
+ }
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil && strings.HasPrefix(value, "!") {
+ return &RoomIDValue{
+ Type: PrimitiveTypeRoomID,
+ RoomID: id.RoomID(value),
+ }, nil
+ }
+ if err != nil {
+ return nil, err
+ } else if parsed.Sigil1 != '!' {
+ return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1)
+ } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' {
+ return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2)
+ }
+ valType := PrimitiveTypeRoomID
+ if parsed.MXID2 != "" {
+ valType = PrimitiveTypeEventID
+ }
+ return &RoomIDValue{
+ Type: valType,
+ RoomID: parsed.RoomID(),
+ Via: parsed.Via,
+ EventID: parsed.EventID(),
+ }, nil
+}
+
+func (pt PrimitiveType) ParseString(value string) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return strconv.Atoi(value)
+ case PrimitiveTypeBoolean:
+ return parseBoolean(value)
+ case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID:
+ return value, pt.validateStringValue(value)
+ case PrimitiveTypeRoomAlias:
+ plainErr := pt.validateStringValue(value)
+ if plainErr == nil {
+ return value, nil
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 != '#' {
+ return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1)
+ }
+ return parsed.RoomAlias(), nil
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, err
+ } else if pt != parsed.Type {
+ return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("cannot parse string for argument type %s", pt)
+ }
+}
+
+func (ps *ParameterSchema) ParseString(value string) (any, error) {
+ if ps == nil {
+ return nil, fmt.Errorf("parameter schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type.ParseString(value)
+ case SchemaTypeLiteral:
+ switch typedValue := ps.Value.(type) {
+ case string:
+ if value == typedValue {
+ return typedValue, nil
+ } else {
+ return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value)
+ }
+ case int, int64, float64, json.Number:
+ expectedVal, _ := normalizeNumber(typedValue)
+ intVal, err := strconv.Atoi(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse integer literal: %w", err)
+ } else if intVal != expectedVal {
+ return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal)
+ }
+ return intVal, nil
+ case bool:
+ boolVal, err := parseBoolean(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse boolean literal: %w", err)
+ } else if boolVal != typedValue {
+ return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal)
+ }
+ return boolVal, nil
+ case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage:
+ expectedVal, _ := NormalizeRoomIDValue(typedValue)
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err)
+ } else if !parsed.Equals(expectedVal) {
+ return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("unsupported literal type %T", ps.Value)
+ }
+ case SchemaTypeUnion:
+ var errs []error
+ for _, variant := range ps.Variants {
+ if parsed, err := variant.ParseString(value); err == nil {
+ return parsed, nil
+ } else {
+ errs = append(errs, err)
+ }
+ }
+ return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...))
+ case SchemaTypeArray:
+ return nil, fmt.Errorf("cannot parse string for array schema type")
+ default:
+ return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType)
+ }
+}
diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go
new file mode 100644
index 00000000..1e0d1817
--- /dev/null
+++ b/event/cmdschema/parse_test.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "bytes"
+ "encoding/json"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exbytes"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/event/cmdschema/testdata"
+)
+
+type QuoteParseOutput struct {
+ Parsed string
+ Remaining string
+ Quoted bool
+}
+
+func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error {
+ var arr []any
+ if err := json.Unmarshal(data, &arr); err != nil {
+ return err
+ }
+ qpo.Parsed = arr[0].(string)
+ qpo.Remaining = arr[1].(string)
+ qpo.Quoted = arr[2].(bool)
+ return nil
+}
+
+type QuoteParseTestData struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Output QuoteParseOutput `json:"output"`
+}
+
+func loadFile[T any](name string) (into T) {
+ quoteData := exerrors.Must(testdata.FS.ReadFile(name))
+ exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into))
+ return
+}
+
+func TestParseQuoted(t *testing.T) {
+ qptd := loadFile[[]QuoteParseTestData]("parse_quote.json")
+ for _, test := range qptd {
+ t.Run(test.Name, func(t *testing.T) {
+ parsed, remaining, quoted := parseQuoted(test.Input)
+ assert.Equalf(t, test.Output, QuoteParseOutput{
+ Parsed: parsed,
+ Remaining: remaining,
+ Quoted: quoted,
+ }, "Failed with input `%s`", test.Input)
+ // Note: can't just test that requoted == input, because some inputs
+ // have unnecessary escapes which won't survive roundtripping
+ t.Run("roundtrip", func(t *testing.T) {
+ requoted := quoteString(parsed) + " " + remaining
+ reparsed, newRemaining, _ := parseQuoted(requoted)
+ assert.Equal(t, parsed, reparsed)
+ assert.Equal(t, remaining, newRemaining)
+ })
+ })
+ }
+}
+
+type CommandTestData struct {
+ Spec *EventContent
+ Tests []*CommandTestUnit
+}
+
+type CommandTestUnit struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Broken string `json:"broken,omitempty"`
+ Error bool `json:"error"`
+ Output json.RawMessage `json:"output"`
+}
+
+func compactJSON(input json.RawMessage) json.RawMessage {
+ var buf bytes.Buffer
+ exerrors.PanicIfNotNil(json.Compact(&buf, input))
+ return buf.Bytes()
+}
+
+func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) {
+ for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) {
+ t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) {
+ ctd := loadFile[CommandTestData]("commands/" + cmd.Name())
+ for _, test := range ctd.Tests {
+ outputStr := exbytes.UnsafeString(compactJSON(test.Output))
+ t.Run(test.Name, func(t *testing.T) {
+ if test.Broken != "" {
+ t.Skip(test.Broken)
+ }
+ output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input)
+ if test.Error {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ if outputStr == "null" {
+ assert.Nil(t, output)
+ } else {
+ assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command)
+ assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go
new file mode 100644
index 00000000..98c421fc
--- /dev/null
+++ b/event/cmdschema/roomid.go
@@ -0,0 +1,135 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+ "strings"
+
+ "maunium.net/go/mautrix/id"
+)
+
+var ParameterSchemaJoinableRoom = Union(
+ PrimitiveTypeRoomID.Schema(),
+ PrimitiveTypeRoomAlias.Schema(),
+)
+
+type RoomIDValue struct {
+ Type PrimitiveType `json:"type"`
+ RoomID id.RoomID `json:"id"`
+ Via []string `json:"via,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+}
+
+func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) {
+ switch typedValue := input.(type) {
+ case map[string]any, json.RawMessage:
+ var raw json.RawMessage
+ if raw, err = json.Marshal(input); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ } else if err = json.Unmarshal(raw, &riv); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ }
+ case *RoomIDValue:
+ riv = typedValue
+ case RoomIDValue:
+ riv = &typedValue
+ default:
+ err = fmt.Errorf("unsupported type %T for room or event ID", input)
+ }
+ return
+}
+
+func (riv *RoomIDValue) String() string {
+ return riv.URI().String()
+}
+
+func (riv *RoomIDValue) URI() *id.MatrixURI {
+ if riv == nil {
+ return nil
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ return riv.RoomID.URI(riv.Via...)
+ case PrimitiveTypeEventID:
+ return riv.RoomID.EventURI(riv.EventID, riv.Via...)
+ default:
+ return nil
+ }
+}
+
+func (riv *RoomIDValue) Equals(other *RoomIDValue) bool {
+ if riv == nil || other == nil {
+ return riv == other
+ }
+ return riv.Type == other.Type &&
+ riv.RoomID == other.RoomID &&
+ riv.EventID == other.EventID &&
+ slices.Equal(riv.Via, other.Via)
+}
+
+func (riv *RoomIDValue) Validate() error {
+ if riv == nil {
+ return fmt.Errorf("value is nil")
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ if riv.EventID != "" {
+ return fmt.Errorf("event ID must be empty for room ID type")
+ }
+ case PrimitiveTypeEventID:
+ if !strings.HasPrefix(riv.EventID.String(), "$") {
+ return fmt.Errorf("event ID not valid: %q", riv.EventID)
+ }
+ default:
+ return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type)
+ }
+ for _, via := range riv.Via {
+ if !id.ValidateServerName(via) {
+ return fmt.Errorf("invalid server name %q in vias", via)
+ }
+ }
+ sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID)
+ if sigil != '!' {
+ return fmt.Errorf("room ID does not start with !: %q", riv.RoomID)
+ } else if localpart == "" && serverName == "" {
+ return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID)
+ } else if serverName != "" && !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name %q in room ID", serverName)
+ }
+ return nil
+}
+
+func (riv *RoomIDValue) IsValid() bool {
+ return riv.Validate() == nil
+}
+
+type RoomIDOrString string
+
+func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error {
+ if len(data) == 0 {
+ return fmt.Errorf("empty data for room ID or string")
+ }
+ if data[0] == '"' {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(str)
+ return nil
+ }
+ var riv RoomIDValue
+ if err := json.Unmarshal(data, &riv); err != nil {
+ return err
+ } else if err = riv.Validate(); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(riv.String())
+ return nil
+}
diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go
new file mode 100644
index 00000000..c5c57c53
--- /dev/null
+++ b/event/cmdschema/stringify.go
@@ -0,0 +1,122 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "strconv"
+ "strings"
+)
+
+var quoteEscaper = strings.NewReplacer(
+ `"`, `\"`,
+ `\`, `\\`,
+)
+
+const charsToQuote = ` \` + botArrayOpener + botArrayCloser
+
+func quoteString(val string) string {
+ if val == "" {
+ return `""`
+ }
+ val = quoteEscaper.Replace(val)
+ if strings.ContainsAny(val, charsToQuote) {
+ return `"` + val + `"`
+ }
+ return val
+}
+
+func (ec *EventContent) StringifyArgs(args any) string {
+ var argMap map[string]any
+ switch typedArgs := args.(type) {
+ case json.RawMessage:
+ err := json.Unmarshal(typedArgs, &argMap)
+ if err != nil {
+ return ""
+ }
+ case map[string]any:
+ argMap = typedArgs
+ default:
+ if b, err := json.Marshal(args); err != nil {
+ return ""
+ } else if err = json.Unmarshal(b, &argMap); err != nil {
+ return ""
+ }
+ }
+ parts := make([]string, 0, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ isLast := i == len(ec.Parameters)-1
+ val := argMap[param.Key]
+ if val == nil {
+ val = param.DefaultValue
+ if val == nil && !param.Optional {
+ val = param.Schema.GetDefaultValue()
+ }
+ }
+ if val == nil {
+ continue
+ }
+ var stringified string
+ if param.Schema.SchemaType == SchemaTypeArray {
+ stringified = arrayArgumentToString(val, isLast)
+ } else {
+ stringified = singleArgumentToString(val)
+ }
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ return strings.Join(parts, " ")
+}
+
+func arrayArgumentToString(val any, isLast bool) string {
+ valArr, ok := val.([]any)
+ if !ok {
+ return ""
+ }
+ parts := make([]string, 0, len(valArr))
+ for _, elem := range valArr {
+ stringified := singleArgumentToString(elem)
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ joinedParts := strings.Join(parts, " ")
+ if isLast && len(parts) > 0 {
+ return joinedParts
+ }
+ return botArrayOpener + joinedParts + botArrayCloser
+}
+
+func singleArgumentToString(val any) string {
+ switch typedVal := val.(type) {
+ case string:
+ return quoteString(typedVal)
+ case json.Number:
+ return typedVal.String()
+ case bool:
+ return strconv.FormatBool(typedVal)
+ case int:
+ return strconv.Itoa(typedVal)
+ case int64:
+ return strconv.FormatInt(typedVal, 10)
+ case float64:
+ return strconv.FormatInt(int64(typedVal), 10)
+ case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue:
+ normalized, err := NormalizeRoomIDValue(typedVal)
+ if err != nil {
+ return ""
+ }
+ uri := normalized.URI()
+ if uri == nil {
+ return ""
+ }
+ return quoteString(uri.String())
+ default:
+ return ""
+ }
+}
diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json
new file mode 100644
index 00000000..e53382db
--- /dev/null
+++ b/event/cmdschema/testdata/commands.schema.json
@@ -0,0 +1,281 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "commands.schema.json",
+ "title": "ParseInput test cases",
+ "description": "JSON schema for test case files containing command specifications and test cases",
+ "type": "object",
+ "required": [
+ "spec",
+ "tests"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "spec": {
+ "title": "MSC4391 Command Description",
+ "description": "JSON schema defining the structure of a bot command event content",
+ "type": "object",
+ "required": [
+ "command"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "command": {
+ "type": "string",
+ "description": "The command name that triggers this bot command"
+ },
+ "aliases": {
+ "type": "array",
+ "description": "Alternative names/aliases for this command",
+ "items": {
+ "type": "string"
+ }
+ },
+ "parameters": {
+ "type": "array",
+ "description": "List of parameters accepted by this command",
+ "items": {
+ "$ref": "#/$defs/Parameter"
+ }
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of the command"
+ },
+ "fi.mau.tail_parameter": {
+ "type": "string",
+ "description": "The key of the parameter that accepts remaining arguments as tail text"
+ },
+ "source": {
+ "type": "string",
+ "description": "The user ID of the bot that responds to this command"
+ }
+ }
+ },
+ "tests": {
+ "type": "array",
+ "description": "Array of test cases for the command",
+ "items": {
+ "type": "object",
+ "description": "A single test case for command parsing",
+ "required": [
+ "name",
+ "input"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "The command input string to parse"
+ },
+ "output": {
+ "description": "The expected parsed parameter values, or null if the parsing is expected to fail",
+ "oneOf": [
+ {
+ "type": "object",
+ "additionalProperties": true
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
+ "error": {
+ "type": "boolean",
+ "description": "Whether parsing should result in an error. May still produce output.",
+ "default": false
+ }
+ }
+ }
+ }
+ },
+ "$defs": {
+ "ExtensibleTextContainer": {
+ "type": "object",
+ "description": "Container for text that can have multiple representations",
+ "required": [
+ "m.text"
+ ],
+ "properties": {
+ "m.text": {
+ "type": "array",
+ "description": "Array of text representations in different formats",
+ "items": {
+ "$ref": "#/$defs/ExtensibleText"
+ }
+ }
+ }
+ },
+ "ExtensibleText": {
+ "type": "object",
+ "description": "A text representation with a specific MIME type",
+ "required": [
+ "body"
+ ],
+ "properties": {
+ "body": {
+ "type": "string",
+ "description": "The text content"
+ },
+ "mimetype": {
+ "type": "string",
+ "description": "The MIME type of the text (e.g., text/plain, text/html)",
+ "default": "text/plain",
+ "examples": [
+ "text/plain",
+ "text/html"
+ ]
+ }
+ }
+ },
+ "Parameter": {
+ "type": "object",
+ "description": "A parameter definition for a command",
+ "required": [
+ "key",
+ "schema"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "key": {
+ "type": "string",
+ "description": "The identifier for this parameter"
+ },
+ "schema": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema defining the type and structure of this parameter"
+ },
+ "optional": {
+ "type": "boolean",
+ "description": "Whether this parameter is optional",
+ "default": false
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of this parameter"
+ },
+ "fi.mau.default_value": {
+ "description": "Default value for this parameter if not provided"
+ }
+ }
+ },
+ "ParameterSchema": {
+ "type": "object",
+ "description": "Schema definition for a parameter value",
+ "required": [
+ "schema_type"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "schema_type": {
+ "type": "string",
+ "enum": [
+ "primitive",
+ "array",
+ "union",
+ "literal"
+ ],
+ "description": "The type of schema"
+ }
+ },
+ "allOf": [
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "primitive"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "type"
+ ],
+ "properties": {
+ "type": {
+ "type": "string",
+ "enum": [
+ "string",
+ "integer",
+ "boolean",
+ "server_name",
+ "user_id",
+ "room_id",
+ "room_alias",
+ "event_id"
+ ],
+ "description": "The primitive type (only for schema_type: primitive)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "array"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "items"
+ ],
+ "properties": {
+ "items": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema for array items (only for schema_type: array)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "union"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "variants"
+ ],
+ "properties": {
+ "variants": {
+ "type": "array",
+ "description": "The possible variants (only for schema_type: union)",
+ "items": {
+ "$ref": "#/$defs/ParameterSchema"
+ },
+ "minItems": 1
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "literal"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "value"
+ ],
+ "properties": {
+ "value": {
+ "description": "The literal value (only for schema_type: literal)"
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+}
diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json
new file mode 100644
index 00000000..6ce1f4da
--- /dev/null
+++ b/event/cmdschema/testdata/commands/flags.json
@@ -0,0 +1,126 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "flag",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "user",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "user_id"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true,
+ "fi.mau.default_value": false
+ }
+ ],
+ "fi.mau.tail_parameter": "user"
+ },
+ "tests": [
+ {
+ "name": "no flags",
+ "input": "/flag mrrp",
+ "output": {
+ "meow": "mrrp",
+ "user": null
+ }
+ },
+ {
+ "name": "no flags, has tail",
+ "input": "/flag mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com"
+ }
+ },
+ {
+ "name": "named flag at start",
+ "input": "/flag --woof=yes mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "boolean flag without value",
+ "input": "/flag --woof mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "user id flag without value",
+ "input": "/flag --user --woof mrrp",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle",
+ "input": "/flag mrrp --woof=yes @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle with different value",
+ "input": "/flag mrrp --woof=no @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named",
+ "input": "/flag --woof=no --meow=mrrp --user=@user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named with quotes",
+ "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"",
+ "output": {
+ "meow": "meow meow mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "invalid value for named parameter",
+ "input": "/flag --user=meowings mrrp --woof",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json
new file mode 100644
index 00000000..1351c292
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_id_or_alias.json
@@ -0,0 +1,85 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "room",
+ "schema": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "room": "#test:matrix.org"
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link with url encoding",
+ "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net",
+ "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!#test/room\nversion 11, with @🐈️:maunium.net",
+ "via": [
+ "maunium.net"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json
new file mode 100644
index 00000000..aa266054
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_reference_list.json
@@ -0,0 +1,106 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "rooms",
+ "schema": {
+ "schema_type": "array",
+ "items": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "rooms": [
+ "#test:matrix.org"
+ ]
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "two room ids",
+ "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"
+ },
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI and matrix.to URL",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ },
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json
new file mode 100644
index 00000000..94667323
--- /dev/null
+++ b/event/cmdschema/testdata/commands/simple.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test simple",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "success",
+ "input": "/test simple mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "directed success",
+ "input": "/test simple@testbot mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "missing parameter",
+ "input": "/test simple",
+ "error": true,
+ "output": {
+ "meow": ""
+ }
+ },
+ {
+ "name": "directed at another bot",
+ "input": "/test simple@anotherbot mrrp",
+ "error": false,
+ "output": null
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json
new file mode 100644
index 00000000..9782f8ec
--- /dev/null
+++ b/event/cmdschema/testdata/commands/tail.json
@@ -0,0 +1,60 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "tail",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "reason",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true
+ }
+ ],
+ "fi.mau.tail_parameter": "reason"
+ },
+ "tests": [
+ {
+ "name": "no tail or flag",
+ "input": "/tail mrrp",
+ "output": {
+ "meow": "mrrp",
+ "reason": ""
+ }
+ },
+ {
+ "name": "tail, no flag",
+ "input": "/tail mrrp meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow"
+ }
+ },
+ {
+ "name": "flag before tail",
+ "input": "/tail mrrp --woof meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow",
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go
new file mode 100644
index 00000000..eceea3d2
--- /dev/null
+++ b/event/cmdschema/testdata/data.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package testdata
+
+import (
+ "embed"
+)
+
+//go:embed *
+var FS embed.FS
diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json
new file mode 100644
index 00000000..8f52b7f5
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.json
@@ -0,0 +1,30 @@
+[
+ {"name": "empty string", "input": "", "output": ["", "", false]},
+ {"name": "single word", "input": "meow", "output": ["meow", "", false]},
+ {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]},
+ {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]},
+ {"name": "only spaces", "input": " ", "output": ["", "", false]},
+ {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]},
+ {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]},
+ {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]},
+ {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]},
+ {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]},
+ {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]},
+ {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]},
+ {"name": "just opening quote", "input": "\"", "output": ["", "", true]},
+ {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]},
+ {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]},
+ {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]},
+ {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]},
+ {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]},
+ {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]},
+ {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]},
+ {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]},
+ {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]},
+ {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]},
+ {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]}
+]
diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json
new file mode 100644
index 00000000..9f249116
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.schema.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "parse_quote.schema.json",
+ "title": "parseQuote test cases",
+ "description": "Test cases for the parseQuoted function",
+ "type": "array",
+ "items": {
+ "type": "object",
+ "required": [
+ "name",
+ "input",
+ "output"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "Name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "Input string to be parsed"
+ },
+ "output": {
+ "type": "array",
+ "description": "Expected output of parsing: [first word, remaining text, was quoted]",
+ "minItems": 3,
+ "maxItems": 3,
+ "prefixItems": [
+ {
+ "type": "string",
+ "description": "First parsed word"
+ },
+ {
+ "type": "string",
+ "description": "Remaining text after the first word"
+ },
+ {
+ "type": "boolean",
+ "description": "Whether the first word was quoted"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false
+ }
+}
diff --git a/event/content.go b/event/content.go
index b56e35f2..814aeec4 100644
--- a/event/content.go
+++ b/event/content.go
@@ -18,6 +18,7 @@ import (
// This is used by Content.ParseRaw() for creating the correct type of struct.
var TypeMap = map[Type]reflect.Type{
StateMember: reflect.TypeOf(MemberEventContent{}),
+ StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}),
StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}),
StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}),
StateRoomName: reflect.TypeOf(RoomNameEventContent{}),
@@ -38,7 +39,9 @@ var TypeMap = map[Type]reflect.Type{
StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}),
StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}),
StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}),
- StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}),
+
+ StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+ StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}),
StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}),
@@ -49,6 +52,7 @@ var TypeMap = map[Type]reflect.Type{
StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}),
StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}),
+ StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}),
EventMessage: reflect.TypeOf(MessageEventContent{}),
EventSticker: reflect.TypeOf(MessageEventContent{}),
@@ -59,8 +63,11 @@ var TypeMap = map[Type]reflect.Type{
EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}),
EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}),
- BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
- BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
+ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
+ BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}),
+ BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}),
+ BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}),
AccountDataRoomTags: reflect.TypeOf(TagEventContent{}),
AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}),
@@ -69,9 +76,11 @@ var TypeMap = map[Type]reflect.Type{
AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}),
AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}),
- EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
- EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
- EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
+ EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
+ EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}),
+ BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}),
InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}),
InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
diff --git a/event/delayed.go b/event/delayed.go
new file mode 100644
index 00000000..fefb62af
--- /dev/null
+++ b/event/delayed.go
@@ -0,0 +1,70 @@
+package event
+
+import (
+ "encoding/json"
+
+ "go.mau.fi/util/jsontime"
+
+ "maunium.net/go/mautrix/id"
+)
+
+type ScheduledDelayedEvent struct {
+ DelayID id.DelayID `json:"delay_id"`
+ RoomID id.RoomID `json:"room_id"`
+ Type Type `json:"type"`
+ StateKey *string `json:"state_key,omitempty"`
+ Delay int64 `json:"delay"`
+ RunningSince jsontime.UnixMilli `json:"running_since"`
+ Content Content `json:"content"`
+}
+
+func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) {
+ evt := &Event{
+ ID: eventID,
+ RoomID: e.RoomID,
+ Type: e.Type,
+ StateKey: e.StateKey,
+ Content: e.Content,
+ Timestamp: ts.UnixMilli(),
+ }
+ return evt, evt.Content.ParseRaw(evt.Type)
+}
+
+type FinalisedDelayedEvent struct {
+ DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"`
+ Outcome DelayOutcome `json:"outcome"`
+ Reason DelayReason `json:"reason"`
+ Error json.RawMessage `json:"error,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+ Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+}
+
+type DelayStatus string
+
+var (
+ DelayStatusScheduled DelayStatus = "scheduled"
+ DelayStatusFinalised DelayStatus = "finalised"
+)
+
+type DelayAction string
+
+var (
+ DelayActionSend DelayAction = "send"
+ DelayActionCancel DelayAction = "cancel"
+ DelayActionRestart DelayAction = "restart"
+)
+
+type DelayOutcome string
+
+var (
+ DelayOutcomeSend DelayOutcome = "send"
+ DelayOutcomeCancel DelayOutcome = "cancel"
+)
+
+type DelayReason string
+
+var (
+ DelayReasonAction DelayReason = "action"
+ DelayReasonError DelayReason = "error"
+ DelayReasonDelay DelayReason = "delay"
+)
diff --git a/event/encryption.go b/event/encryption.go
index cf9c2814..c60cb91a 100644
--- a/event/encryption.go
+++ b/event/encryption.go
@@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
case id.AlgorithmMegolmV1:
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
- return id.InputNotJSONString
+ return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString)
}
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
}
@@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct {
type RequestedKeyInfo struct {
Algorithm id.Algorithm `json:"algorithm"`
RoomID id.RoomID `json:"room_id"`
- SenderKey id.SenderKey `json:"sender_key"`
SessionID id.SessionID `json:"session_id"`
+ // Deprecated: Matrix v1.3
+ SenderKey id.SenderKey `json:"sender_key"`
}
type RoomKeyWithheldCode string
diff --git a/event/events.go b/event/events.go
index a763cc31..72c1e161 100644
--- a/event/events.go
+++ b/event/events.go
@@ -130,36 +130,29 @@ func (evt *Event) GetStateKey() string {
return ""
}
-type StrippedState struct {
- Content Content `json:"content"`
- Type Type `json:"type"`
- StateKey string `json:"state_key"`
- Sender id.UserID `json:"sender"`
-}
-
type Unsigned struct {
- PrevContent *Content `json:"prev_content,omitempty"`
- PrevSender id.UserID `json:"prev_sender,omitempty"`
- Membership Membership `json:"membership,omitempty"`
- ReplacesState id.EventID `json:"replaces_state,omitempty"`
- Age int64 `json:"age,omitempty"`
- TransactionID string `json:"transaction_id,omitempty"`
- Relations *Relations `json:"m.relations,omitempty"`
- RedactedBecause *Event `json:"redacted_because,omitempty"`
- InviteRoomState []StrippedState `json:"invite_room_state,omitempty"`
+ PrevContent *Content `json:"prev_content,omitempty"`
+ PrevSender id.UserID `json:"prev_sender,omitempty"`
+ Membership Membership `json:"membership,omitempty"`
+ ReplacesState id.EventID `json:"replaces_state,omitempty"`
+ Age int64 `json:"age,omitempty"`
+ TransactionID string `json:"transaction_id,omitempty"`
+ Relations *Relations `json:"m.relations,omitempty"`
+ RedactedBecause *Event `json:"redacted_because,omitempty"`
+ InviteRoomState []*Event `json:"invite_room_state,omitempty"`
BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"`
BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"`
BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"`
- MauSoftFailed bool `json:"fi.mau.soft_failed,omitempty"`
- MauRejectionReason string `json:"fi.mau.rejection_reason,omitempty"`
+ ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"`
+ ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"`
}
func (us *Unsigned) IsEmpty() bool {
return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" &&
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil &&
us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() &&
- !us.MauSoftFailed && us.MauRejectionReason == ""
+ !us.ElementSoftFailed
}
diff --git a/event/member.go b/event/member.go
index 02b7cae9..9956a36b 100644
--- a/event/member.go
+++ b/event/member.go
@@ -7,8 +7,6 @@
package event
import (
- "encoding/json"
-
"maunium.net/go/mautrix/id"
)
@@ -35,22 +33,37 @@ const (
// MemberEventContent represents the content of a m.room.member state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroommember
type MemberEventContent struct {
- Membership Membership `json:"membership"`
- AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
- Displayname string `json:"displayname,omitempty"`
- IsDirect bool `json:"is_direct,omitempty"`
- ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
- Reason string `json:"reason,omitempty"`
- MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
+ Membership Membership `json:"membership"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ Displayname string `json:"displayname,omitempty"`
+ IsDirect bool `json:"is_direct,omitempty"`
+ ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"`
+ MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
}
-type ThirdPartyInvite struct {
- DisplayName string `json:"display_name"`
- Signed struct {
- Token string `json:"token"`
- Signatures json.RawMessage `json:"signatures"`
- MXID string `json:"mxid"`
- }
+type SignedThirdPartyInvite struct {
+ Token string `json:"token"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
+ MXID string `json:"mxid"`
+}
+
+type ThirdPartyInvite struct {
+ DisplayName string `json:"display_name"`
+ Signed SignedThirdPartyInvite `json:"signed"`
+}
+
+type ThirdPartyInviteEventContent struct {
+ DisplayName string `json:"display_name"`
+ KeyValidityURL string `json:"key_validity_url"`
+ PublicKey id.Ed25519 `json:"public_key"`
+ PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"`
+}
+
+type ThirdPartyInviteKey struct {
+ KeyValidityURL string `json:"key_validity_url,omitempty"`
+ PublicKey id.Ed25519 `json:"public_key"`
}
diff --git a/event/message.go b/event/message.go
index 51403889..3fb3dc82 100644
--- a/event/message.go
+++ b/event/message.go
@@ -135,11 +135,16 @@ type MessageEventContent struct {
BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"`
+ BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"`
BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"`
+ BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"`
+
MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"`
MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"`
+
+ MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"`
}
func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType {
@@ -271,6 +276,25 @@ func (m *Mentions) Has(userID id.UserID) bool {
return m != nil && slices.Contains(m.UserIDs, userID)
}
+func (m *Mentions) Merge(other *Mentions) *Mentions {
+ if m == nil {
+ return other
+ } else if other == nil {
+ return m
+ }
+ return &Mentions{
+ UserIDs: slices.Concat(m.UserIDs, other.UserIDs),
+ Room: m.Room || other.Room,
+ }
+}
+
+type MSC4391BotCommandInputCustom[T any] struct {
+ Command string `json:"command"`
+ Arguments T `json:"arguments,omitempty"`
+}
+
+type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage]
+
type EncryptedFileInfo struct {
attachment.EncryptedFile
URL id.ContentURIString `json:"url"`
@@ -285,7 +309,8 @@ type FileInfo struct {
Blurhash string
AnoaBlurhash string
- MauGIF bool
+ MauGIF bool
+ IsAnimated bool
Width int
Height int
@@ -302,7 +327,8 @@ type serializableFileInfo struct {
Blurhash string `json:"blurhash,omitempty"`
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
- MauGIF bool `json:"fi.mau.gif,omitempty"`
+ MauGIF bool `json:"fi.mau.gif,omitempty"`
+ IsAnimated bool `json:"is_animated,omitempty"`
Width json.Number `json:"w,omitempty"`
Height json.Number `json:"h,omitempty"`
@@ -320,7 +346,8 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI
ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo),
ThumbnailFile: fileInfo.ThumbnailFile,
- MauGIF: fileInfo.MauGIF,
+ MauGIF: fileInfo.MauGIF,
+ IsAnimated: fileInfo.IsAnimated,
Blurhash: fileInfo.Blurhash,
AnoaBlurhash: fileInfo.AnoaBlurhash,
@@ -351,6 +378,7 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) {
ThumbnailURL: sfi.ThumbnailURL,
ThumbnailFile: sfi.ThumbnailFile,
MauGIF: sfi.MauGIF,
+ IsAnimated: sfi.IsAnimated,
Blurhash: sfi.Blurhash,
AnoaBlurhash: sfi.AnoaBlurhash,
}
diff --git a/event/message_test.go b/event/message_test.go
index 562a6622..c721df35 100644
--- a/event/message_test.go
+++ b/event/message_test.go
@@ -33,7 +33,7 @@ const invalidMessageEvent = `{
func TestMessageEventContent__ParseInvalid(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(invalidMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) {
assert.Equal(t, id.RoomID("!bar"), evt.RoomID)
err = evt.Content.ParseRaw(evt.Type)
- assert.NotNil(t, err)
+ assert.Error(t, err)
}
const messageEvent = `{
@@ -68,7 +68,7 @@ const messageEvent = `{
func TestMessageEventContent__ParseEdit(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(messageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -110,7 +110,7 @@ const imageMessageEvent = `{
func TestMessageEventContent__ParseMedia(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(imageMessageEvent), &evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) {
content := evt.Content.Parsed.(*event.MessageEventContent)
assert.Equal(t, event.MsgImage, content.MsgType)
parsedURL, err := content.URL.Parse()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL)
assert.Nil(t, content.NewContent)
assert.Equal(t, "image/png", content.GetInfo().MimeType)
@@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}`
func TestMessageEventContent__Marshal(t *testing.T) {
data, err := json.Marshal(parsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedMarshalResult, string(data))
}
@@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun
func TestMessageEventContent__Marshal_Custom(t *testing.T) {
data, err := json.Marshal(customParsedMessage)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, expectedCustomMarshalResult, string(data))
}
diff --git a/event/poll.go b/event/poll.go
index 47131a8f..9082f65e 100644
--- a/event/poll.go
+++ b/event/poll.go
@@ -35,7 +35,7 @@ type MSC1767Message struct {
}
type PollStartEventContent struct {
- RelatesTo *RelatesTo `json:"m.relates_to"`
+ RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
Mentions *Mentions `json:"m.mentions,omitempty"`
PollStart struct {
Kind string `json:"kind"`
diff --git a/event/powerlevels.go b/event/powerlevels.go
index 2f4d4573..668eb6d3 100644
--- a/event/powerlevels.go
+++ b/event/powerlevels.go
@@ -7,6 +7,8 @@
package event
import (
+ "math"
+ "slices"
"sync"
"go.mau.fi/util/ptr"
@@ -26,6 +28,9 @@ type PowerLevelsEventContent struct {
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
+ beeperEphemeralLock sync.RWMutex
+ BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"`
+
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
StateDefaultPtr *int `json:"state_default,omitempty"`
@@ -34,6 +39,12 @@ type PowerLevelsEventContent struct {
KickPtr *int `json:"kick,omitempty"`
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
+
+ BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"`
+
+ // This is not a part of power levels, it's added by mautrix-go internally in certain places
+ // in order to detect creator power accurately.
+ CreateEvent *Event `json:"-"`
}
func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
@@ -45,6 +56,7 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
UsersDefault: pl.UsersDefault,
Events: maps.Clone(pl.Events),
EventsDefault: pl.EventsDefault,
+ BeeperEphemeral: maps.Clone(pl.BeeperEphemeral),
StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr),
Notifications: pl.Notifications.Clone(),
@@ -53,6 +65,10 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
KickPtr: ptr.Clone(pl.KickPtr),
BanPtr: ptr.Clone(pl.BanPtr),
RedactPtr: ptr.Clone(pl.RedactPtr),
+
+ BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr),
+
+ CreateEvent: pl.CreateEvent,
}
}
@@ -111,7 +127,17 @@ func (pl *PowerLevelsEventContent) StateDefault() int {
return 50
}
+func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int {
+ if pl.BeeperEphemeralDefaultPtr != nil {
+ return *pl.BeeperEphemeralDefaultPtr
+ }
+ return pl.EventsDefault
+}
+
func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
+ if pl.isCreator(userID) {
+ return math.MaxInt
+ }
pl.usersLock.RLock()
defer pl.usersLock.RUnlock()
level, ok := pl.Users[userID]
@@ -121,9 +147,19 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
return level
}
+const maxPL = 1<<53 - 1
+
func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) {
pl.usersLock.Lock()
defer pl.usersLock.Unlock()
+ if pl.isCreator(userID) {
+ return
+ }
+ if level == math.MaxInt && maxPL < math.MaxInt {
+ // Hack to avoid breaking on 32-bit systems (they're only slightly supported)
+ x := int64(maxPL)
+ level = int(x)
+ }
if level == pl.UsersDefault {
delete(pl.Users, userID)
} else {
@@ -138,9 +174,24 @@ func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int)
return pl.EnsureUserLevelAs("", target, level)
}
+func (pl *PowerLevelsEventContent) createContent() *CreateEventContent {
+ if pl.CreateEvent == nil {
+ return &CreateEventContent{}
+ }
+ return pl.CreateEvent.Content.AsCreate()
+}
+
+func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool {
+ cc := pl.createContent()
+ return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID))
+}
+
func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool {
+ if pl.isCreator(target) {
+ return false
+ }
existingLevel := pl.GetUserLevel(target)
- if actor != "" {
+ if actor != "" && !pl.isCreator(actor) {
actorLevel := pl.GetUserLevel(actor)
if actorLevel <= existingLevel || actorLevel < level {
return false
@@ -166,6 +217,29 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int {
return level
}
+func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int {
+ pl.beeperEphemeralLock.RLock()
+ defer pl.beeperEphemeralLock.RUnlock()
+ level, ok := pl.BeeperEphemeral[eventType.String()]
+ if !ok {
+ return pl.BeeperEphemeralDefault()
+ }
+ return level
+}
+
+func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) {
+ pl.beeperEphemeralLock.Lock()
+ defer pl.beeperEphemeralLock.Unlock()
+ if level == pl.BeeperEphemeralDefault() {
+ delete(pl.BeeperEphemeral, eventType.String())
+ } else {
+ if pl.BeeperEphemeral == nil {
+ pl.BeeperEphemeral = make(map[string]int)
+ }
+ pl.BeeperEphemeral[eventType.String()] = level
+ }
+}
+
func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) {
pl.eventsLock.Lock()
defer pl.eventsLock.Unlock()
@@ -185,7 +259,7 @@ func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) b
func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool {
existingLevel := pl.GetEventLevel(eventType)
- if actor != "" {
+ if actor != "" && !pl.isCreator(actor) {
actorLevel := pl.GetUserLevel(actor)
if existingLevel > actorLevel || level > actorLevel {
return false
diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go
new file mode 100644
index 00000000..f5861583
--- /dev/null
+++ b/event/powerlevels_ephemeral_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/event"
+)
+
+func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 45,
+ }
+
+ assert.Equal(t, 45, pl.BeeperEphemeralDefault())
+
+ override := 60
+ pl.BeeperEphemeralDefaultPtr = &override
+ assert.Equal(t, 60, pl.BeeperEphemeralDefault())
+}
+
+func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 25,
+ }
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+
+ assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType))
+
+ pl.SetBeeperEphemeralLevel(evtType, 50)
+ assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType))
+ require.NotNil(t, pl.BeeperEphemeral)
+ assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()])
+
+ pl.SetBeeperEphemeralLevel(evtType, 25)
+ _, exists := pl.BeeperEphemeral[evtType.String()]
+ assert.False(t, exists)
+}
+
+func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) {
+ override := 70
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 35,
+ BeeperEphemeral: map[string]int{"com.example.ephemeral": 90},
+ BeeperEphemeralDefaultPtr: &override,
+ }
+
+ cloned := pl.Clone()
+ require.NotNil(t, cloned)
+ require.NotNil(t, cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"])
+
+ cloned.BeeperEphemeral["com.example.ephemeral"] = 99
+ *cloned.BeeperEphemeralDefaultPtr = 71
+
+ assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"])
+ assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr)
+}
diff --git a/event/reply.go b/event/reply.go
index 9ae1c110..5f55bb80 100644
--- a/event/reply.go
+++ b/event/reply.go
@@ -32,12 +32,13 @@ func TrimReplyFallbackText(text string) string {
}
func (content *MessageEventContent) RemoveReplyFallback() {
- if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
- if content.Format == FormatHTML {
- content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML {
+ origHTML := content.FormattedBody
+ content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
+ if content.FormattedBody != origHTML {
+ content.Body = TrimReplyFallbackText(content.Body)
+ content.replyFallbackRemoved = true
}
- content.Body = TrimReplyFallbackText(content.Body)
- content.replyFallbackRemoved = true
}
}
diff --git a/event/state.go b/event/state.go
index 028691e1..ace170a5 100644
--- a/event/state.go
+++ b/event/state.go
@@ -8,6 +8,10 @@ package event
import (
"encoding/base64"
+ "encoding/json"
+ "slices"
+
+ "go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/id"
)
@@ -52,10 +56,40 @@ type TopicEventContent struct {
// m.room.topic state event as described in [MSC3765].
//
// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765
-type ExtensibleTopic struct {
+type ExtensibleTopic = ExtensibleTextContainer
+
+type ExtensibleTextContainer struct {
Text []ExtensibleText `json:"m.text"`
}
+func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool {
+ if c == nil || description == nil {
+ return c == description
+ }
+ return slices.Equal(c.Text, description.Text)
+}
+
+func MakeExtensibleText(text string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: text,
+ MimeType: "text/plain",
+ }},
+ }
+}
+
+func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer {
+ return &ExtensibleTextContainer{
+ Text: []ExtensibleText{{
+ Body: plaintext,
+ MimeType: "text/plain",
+ }, {
+ Body: html,
+ MimeType: "text/html",
+ }},
+ }
+}
+
// ExtensibleText represents the contents of an m.text field.
type ExtensibleText struct {
MimeType string `json:"mimetype,omitempty"`
@@ -69,39 +103,66 @@ type TombstoneEventContent struct {
ReplacementRoom id.RoomID `json:"replacement_room"`
}
+func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID {
+ if tec == nil {
+ return ""
+ }
+ return tec.ReplacementRoom
+}
+
type Predecessor struct {
RoomID id.RoomID `json:"room_id"`
EventID id.EventID `json:"event_id"`
}
-type RoomVersion string
+// Deprecated: use id.RoomVersion instead
+type RoomVersion = id.RoomVersion
+// Deprecated: use id.RoomVX constants instead
const (
- RoomV1 RoomVersion = "1"
- RoomV2 RoomVersion = "2"
- RoomV3 RoomVersion = "3"
- RoomV4 RoomVersion = "4"
- RoomV5 RoomVersion = "5"
- RoomV6 RoomVersion = "6"
- RoomV7 RoomVersion = "7"
- RoomV8 RoomVersion = "8"
- RoomV9 RoomVersion = "9"
- RoomV10 RoomVersion = "10"
- RoomV11 RoomVersion = "11"
+ RoomV1 = id.RoomV1
+ RoomV2 = id.RoomV2
+ RoomV3 = id.RoomV3
+ RoomV4 = id.RoomV4
+ RoomV5 = id.RoomV5
+ RoomV6 = id.RoomV6
+ RoomV7 = id.RoomV7
+ RoomV8 = id.RoomV8
+ RoomV9 = id.RoomV9
+ RoomV10 = id.RoomV10
+ RoomV11 = id.RoomV11
+ RoomV12 = id.RoomV12
)
// CreateEventContent represents the content of a m.room.create state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomcreate
type CreateEventContent struct {
- Type RoomType `json:"type,omitempty"`
- Federate *bool `json:"m.federate,omitempty"`
- RoomVersion RoomVersion `json:"room_version,omitempty"`
- Predecessor *Predecessor `json:"predecessor,omitempty"`
+ Type RoomType `json:"type,omitempty"`
+ Federate *bool `json:"m.federate,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Predecessor *Predecessor `json:"predecessor,omitempty"`
+
+ // Room v12+ only
+ AdditionalCreators []id.UserID `json:"additional_creators,omitempty"`
// Deprecated: use the event sender instead
Creator id.UserID `json:"creator,omitempty"`
}
+func (cec *CreateEventContent) GetPredecessor() (p Predecessor) {
+ if cec != nil && cec.Predecessor != nil {
+ p = *cec.Predecessor
+ }
+ return
+}
+
+func (cec *CreateEventContent) SupportsCreatorPower() bool {
+ if cec == nil {
+ return false
+ }
+ return cec.RoomVersion.PrivilegedRoomCreators()
+}
+
// JoinRule specifies how open a room is to new members.
// https://spec.matrix.org/v1.2/client-server-api/#mroomjoin_rules
type JoinRule string
@@ -177,7 +238,8 @@ type BridgeInfoSection struct {
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
ExternalURL string `json:"external_url,omitempty"`
- Receiver string `json:"fi.mau.receiver,omitempty"`
+ Receiver string `json:"fi.mau.receiver,omitempty"`
+ MessageRequest bool `json:"com.beeper.message_request,omitempty"`
}
// BridgeEventContent represents the content of a m.bridge state event.
@@ -191,6 +253,32 @@ type BridgeEventContent struct {
BeeperRoomType string `json:"com.beeper.room_type,omitempty"`
BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"`
+
+ TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"`
+ TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"`
+}
+
+// DisappearingType represents the type of a disappearing message timer.
+type DisappearingType string
+
+const (
+ DisappearingTypeNone DisappearingType = ""
+ DisappearingTypeAfterRead DisappearingType = "after_read"
+ DisappearingTypeAfterSend DisappearingType = "after_send"
+)
+
+type BeeperDisappearingTimer struct {
+ Type DisappearingType `json:"type"`
+ Timer jsontime.Milliseconds `json:"timer"`
+}
+
+type marshalableBeeperDisappearingTimer BeeperDisappearingTimer
+
+func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) {
+ if bdt == nil || bdt.Type == DisappearingTypeNone {
+ return []byte("{}"), nil
+ }
+ return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt))
}
type SpaceChildEventContent struct {
@@ -244,12 +332,26 @@ func (mpc *ModPolicyContent) EntityOrHash() string {
return mpc.Entity
}
-// Deprecated: MSC2716 has been abandoned
-type InsertionMarkerContent struct {
- InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"`
- Timestamp int64 `json:"com.beeper.timestamp,omitempty"`
-}
-
type ElementFunctionalMembersContent struct {
ServiceMembers []id.UserID `json:"service_members"`
}
+
+func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool {
+ if slices.Contains(efmc.ServiceMembers, mxid) {
+ return false
+ }
+ efmc.ServiceMembers = append(efmc.ServiceMembers, mxid)
+ return true
+}
+
+type PolicyServerPublicKeys struct {
+ Ed25519 id.Ed25519 `json:"ed25519,omitempty"`
+}
+
+type RoomPolicyEventContent struct {
+ Via string `json:"via,omitempty"`
+ PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"`
+
+ // Deprecated, only for legacy use
+ PublicKey id.Ed25519 `json:"public_key,omitempty"`
+}
diff --git a/event/type.go b/event/type.go
index 591d598d..80b86728 100644
--- a/event/type.go
+++ b/event/type.go
@@ -108,13 +108,14 @@ func (et *Type) IsCustom() bool {
func (et *Type) GuessClass() TypeClass {
switch et.Type {
- case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type,
+ case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type,
StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type,
StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type,
StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type,
- StateInsertionMarker.Type, StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type:
+ StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type,
+ StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type:
return StateEventType
- case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type:
+ case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type:
return EphemeralEventType
case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type,
AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type,
@@ -127,7 +128,7 @@ func (et *Type) GuessClass() TypeClass {
InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type,
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type,
- BeeperTranscription.Type:
+ EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
@@ -177,6 +178,7 @@ var (
StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType}
StateGuestAccess = Type{"m.room.guest_access", StateEventType}
StateMember = Type{"m.room.member", StateEventType}
+ StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType}
StatePowerLevels = Type{"m.room.power_levels", StateEventType}
StateRoomName = Type{"m.room.name", StateEventType}
StateTopic = Type{"m.room.topic", StateEventType}
@@ -193,6 +195,9 @@ var (
StateSpaceChild = Type{"m.space.child", StateEventType}
StateSpaceParent = Type{"m.space.parent", StateEventType}
+ StateRoomPolicy = Type{"m.room.policy", StateEventType}
+ StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType}
+
StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType}
StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType}
StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType}
@@ -200,11 +205,10 @@ var (
StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType}
StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType}
- // Deprecated: MSC2716 has been abandoned
- StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType}
-
StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType}
StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType}
+ StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType}
+ StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType}
)
// Message events
@@ -233,18 +237,24 @@ var (
CallNegotiate = Type{"m.call.negotiate", MessageEventType}
CallHangup = Type{"m.call.hangup", MessageEventType}
- BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
- BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
+ BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
+ BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType}
+ BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType}
+ BeeperSendState = Type{"com.beeper.send_state", MessageEventType}
EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType}
EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType}
+ EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType}
)
// Ephemeral events
var (
- EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
- EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
- EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
+ EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
+ EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType}
+ BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType}
)
// Account data events
diff --git a/example/main.go b/example/main.go
index d8006d46..2bf4bef3 100644
--- a/example/main.go
+++ b/example/main.go
@@ -143,7 +143,7 @@ func main() {
if err != nil {
log.Error().Err(err).Msg("Failed to send event")
} else {
- log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent")
+ log.Info().Stringer("event_id", resp.EventID).Msg("Event sent")
}
}
cancelSync()
diff --git a/federation/client.go b/federation/client.go
index 7c460d44..183fb5d1 100644
--- a/federation/client.go
+++ b/federation/client.go
@@ -21,6 +21,7 @@ import (
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -29,6 +30,8 @@ type Client struct {
ServerName string
UserAgent string
Key *SigningKey
+
+ ResponseSizeLimit int64
}
func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client {
@@ -36,10 +39,16 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien
HTTP: &http.Client{
Transport: NewServerResolvingTransport(cache),
Timeout: 120 * time.Second,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ // Federation requests do not allow redirects.
+ return http.ErrUseLastResponse
+ },
},
UserAgent: mautrix.DefaultUserAgent,
ServerName: serverName,
Key: key,
+
+ ResponseSizeLimit: mautrix.DefaultResponseSizeLimit,
}
}
@@ -80,7 +89,7 @@ type RespSendTransaction struct {
}
func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) {
- err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp)
+ err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp)
return
}
@@ -254,6 +263,169 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken
return
}
+type ReqMakeJoin struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+ SupportedVersions []id.RoomVersion
+}
+
+type RespMakeJoin struct {
+ RoomVersion id.RoomVersion `json:"room_version"`
+ Event PDU `json:"event"`
+}
+
+type ReqSendJoin struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ OmitMembers bool
+ Event PDU
+ Via string
+}
+
+type ReqSendKnock struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type RespSendJoin struct {
+ AuthChain []PDU `json:"auth_chain"`
+ Event PDU `json:"event"`
+ MembersOmitted bool `json:"members_omitted"`
+ ServersInRoom []string `json:"servers_in_room"`
+ State []PDU `json:"state"`
+}
+
+type RespSendKnock struct {
+ KnockRoomState []PDU `json:"knock_room_state"`
+}
+
+type ReqSendInvite struct {
+ RoomID id.RoomID `json:"-"`
+ UserID id.UserID `json:"-"`
+ Event PDU `json:"event"`
+ InviteRoomState []PDU `json:"invite_room_state"`
+ RoomVersion id.RoomVersion `json:"room_version"`
+}
+
+type RespSendInvite struct {
+ Event PDU `json:"event"`
+}
+
+type ReqMakeLeave struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+}
+
+type ReqSendLeave struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type (
+ ReqMakeKnock = ReqMakeJoin
+ RespMakeKnock = RespMakeJoin
+ RespMakeLeave = RespMakeJoin
+)
+
+func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_join", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_join", req.RoomID, req.EventID},
+ Query: url.Values{
+ "omit_members": {strconv.FormatBool(req.OmitMembers)},
+ },
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.UserID.Homeserver(),
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "invite", req.RoomID, req.UserID},
+ Authenticate: true,
+ RequestJSON: req,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ })
+ return
+}
+
type URLPath []any
func (fup URLPath) FullPath() []any {
@@ -305,15 +477,27 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
WrappedError: err,
}
}
- defer func() {
- _ = resp.Body.Close()
- }()
+ if !params.DontReadBody {
+ defer resp.Body.Close()
+ }
var body []byte
- if resp.StatusCode >= 400 {
+ if resp.StatusCode >= 300 {
body, err = mautrix.ParseErrorResponse(req, resp)
return body, resp, err
} else if params.ResponseJSON != nil || !params.DontReadBody {
- body, err = io.ReadAll(resp.Body)
+ if resp.ContentLength > c.ResponseSizeLimit {
+ return body, resp, mautrix.HTTPError{
+ Request: req,
+ Response: resp,
+
+ Message: "not reading response",
+ WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024),
+ }
+ }
+ body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1))
+ if err == nil && len(body) > int(c.ResponseSizeLimit) {
+ err = mautrix.ErrBodyReadReachedLimit
+ }
if err != nil {
return body, resp, mautrix.HTTPError{
Request: req,
@@ -404,7 +588,7 @@ func (r *signableRequest) Verify(key id.SigningKey, sig string) error {
if err != nil {
return fmt.Errorf("failed to marshal data: %w", err)
}
- return VerifyJSONRaw(key, sig, message)
+ return signutil.VerifyJSONRaw(key, sig, message)
}
func (r *signableRequest) Sign(key *SigningKey) (string, error) {
diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go
new file mode 100644
index 00000000..c72933c2
--- /dev/null
+++ b/federation/eventauth/eventauth.go
@@ -0,0 +1,851 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "encoding/json"
+ "encoding/json/jsontext"
+ "errors"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/exstrings"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type AuthFailError struct {
+ Index string
+ Message string
+ Wrapped error
+}
+
+func (afe AuthFailError) Error() string {
+ if afe.Message != "" {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message)
+ } else if afe.Wrapped != nil {
+ return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error())
+ }
+ return fmt.Sprintf("fail %s", afe.Index)
+}
+
+func (afe AuthFailError) Unwrap() error {
+ return afe.Wrapped
+}
+
+var mFederatePath = exgjson.Path("m.federate")
+
+var (
+ ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"}
+ ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"}
+ ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"}
+ ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion}
+ ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"}
+ ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"}
+
+ ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"}
+ ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"}
+ ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"}
+ ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"}
+
+ ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"}
+ ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"}
+ ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"}
+ ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"}
+ ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"}
+ ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"}
+ ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"}
+ ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"}
+ ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"}
+
+ ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"}
+
+ ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"}
+ ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"}
+ ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"}
+ ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"}
+ ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"}
+ ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"}
+ ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"}
+ ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"}
+ ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"}
+ ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"}
+ ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"}
+ ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"}
+ ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"}
+ ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"}
+ ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"}
+ ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"}
+ ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"}
+ ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"}
+ ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"}
+ ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"}
+ ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"}
+ ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"}
+ ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"}
+ ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"}
+ ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"}
+ ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"}
+ ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"}
+ ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"}
+
+ ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"}
+
+ ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"}
+
+ ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"}
+
+ ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"}
+
+ ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"}
+ ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"}
+ ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"}
+ ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"}
+ ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"}
+ ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"}
+ ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"}
+)
+
+func isRejected(evt *pdu.PDU) bool {
+ return evt.InternalMeta.Rejected
+}
+
+type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error)
+
+func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error {
+ if evt.Type == event.StateCreate.Type {
+ // 1. If type is m.room.create:
+ return authorizeCreate(roomVersion, evt)
+ }
+ var createEvt *pdu.PDU
+ if roomVersion.RoomIDIsCreateEventID() {
+ // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event,
+ // with the sigil ! instead of $, reject.
+ if len(evt.RoomID) != 44 {
+ return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID))
+ } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err)
+ } else if len(createEvts) != 1 {
+ return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID)
+ } else if isRejected(createEvts[0]) {
+ return ErrRejectedCreateEvent
+ } else {
+ createEvt = createEvts[0]
+ }
+ }
+ authEvents, err := getEvents(evt.AuthEvents)
+ if err != nil {
+ return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err)
+ }
+ expectedAuthEvents := evt.AuthEventSelection(roomVersion)
+ deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents))
+ // 3. Considering the event’s auth_events:
+ for i, ae := range authEvents {
+ authEvtID := evt.AuthEvents[i]
+ if ae == nil {
+ return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID)
+ } else if ae.StateKey == nil {
+ // This approximately falls under rule 3.2.
+ return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID)
+ }
+ key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey}
+ if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound {
+ // 3.1. If there are duplicate entries for a given type and state_key pair, reject.
+ return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID)
+ } else if !expectedAuthEvents.Has(key) {
+ // 3.2. If there are entries whose type and state_key don’t match those specified by
+ // the auth events selection algorithm described in the server specification, reject.
+ return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey)
+ } else if isRejected(ae) {
+ // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject.
+ return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID)
+ } else if ae.RoomID != evt.RoomID {
+ // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject.
+ return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID)
+ } else {
+ deduplicator[key] = authEvtID
+ }
+ if ae.Type == event.StateCreate.Type {
+ if createEvt == nil {
+ createEvt = ae
+ } else {
+ // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+
+ panic(fmt.Errorf("impossible case: multiple create events found in auth events"))
+ }
+ }
+ }
+ if createEvt == nil {
+ // This comes either from auth_events or room_id depending on the room version.
+ // The checks above make sure it's from the right source.
+ return ErrNoCreateEvent
+ }
+ if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() {
+ // 4. If the content of the m.room.create event in the room state has the property m.federate set to false,
+ // and the sender domain of the event does not match the sender domain of the create event, reject.
+ return ErrFederationDisabled
+ }
+ if evt.Type == event.StateMember.Type {
+ // 5. If type is m.room.member:
+ return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey)
+ }
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ if senderMembership != event.MembershipJoin {
+ // 6. If the sender’s current membership state is not join, reject.
+ return ErrNotInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderPL := powerLevels.GetUserLevel(evt.Sender)
+ if evt.Type == event.StateThirdPartyInvite.Type {
+ // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level.
+ if senderPL >= powerLevels.Invite() {
+ return nil
+ }
+ return ErrInsufficientPowerForThirdPartyInvite
+ }
+ typeClass := event.MessageEventType
+ if evt.StateKey != nil {
+ typeClass = event.StateEventType
+ }
+ evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass})
+ if evtLevel > senderPL {
+ // 8. If the event type’s required power level is greater than the sender’s power level, reject.
+ return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL)
+ }
+
+ if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() {
+ // 9. If the event has a state_key that starts with an @ and does not match the sender, reject.
+ return ErrMismatchingPrivateStateKey
+ }
+
+ if evt.Type == event.StatePowerLevels.Type {
+ // 10. If type is m.room.power_levels:
+ return authorizePowerLevels(roomVersion, evt, createEvt, authEvents)
+ }
+
+ // 11. Otherwise, allow.
+ return nil
+}
+
+var ErrUserIDNotAString = errors.New("not a string")
+var ErrUserIDNotValid = errors.New("not a valid user ID")
+
+func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error {
+ if userID.Type != gjson.String {
+ return ErrUserIDNotAString
+ }
+ // In a future room version, user IDs will have stricter validation
+ _, _, err := id.UserID(userID.Str).Parse()
+ if err != nil {
+ return ErrUserIDNotValid
+ }
+ return nil
+}
+
+func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error {
+ if len(evt.PrevEvents) > 0 {
+ // 1.1. If it has any prev_events, reject.
+ return ErrCreateHasPrevEvents
+ }
+ if roomVersion.RoomIDIsCreateEventID() {
+ if evt.RoomID != "" {
+ // 1.2. If the event has a room_id, reject.
+ return ErrCreateHasRoomID
+ }
+ } else {
+ _, _, server := id.ParseCommonIdentifier(evt.RoomID)
+ if server == "" || server != evt.Sender.Homeserver() {
+ // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject.
+ return ErrRoomIDDoesntMatchSender
+ }
+ }
+ if !roomVersion.IsKnown() {
+ // 1.3. If content.room_version is present and is not a recognised version, reject.
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ additionalCreators := gjson.GetBytes(evt.Content, "additional_creators")
+ if additionalCreators.Exists() {
+ if !additionalCreators.IsArray() {
+ return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators)
+ }
+ for i, item := range additionalCreators.Array() {
+ // 1.4. If additional_creators is present in content and is not an array of strings
+ // where each string passes the same user ID validation applied to sender, reject.
+ if err := isValidUserID(roomVersion, item); err != nil {
+ return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err)
+ }
+ }
+ }
+ }
+ if roomVersion.CreatorInContent() {
+ // 1.4. (v10 and below) If content has no creator property, reject.
+ if !gjson.GetBytes(evt.Content, "creator").Exists() {
+ return ErrMissingCreator
+ }
+ }
+ // 1.5. Otherwise, allow.
+ return nil
+}
+
+func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error {
+ membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str)
+ if evt.StateKey == nil {
+ // 5.1. If there is no state_key property, or no membership property in content, reject.
+ return ErrMemberNotState
+ }
+ authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str)
+ if authorizedVia != "" {
+ homeserver := authorizedVia.Homeserver()
+ err := evt.VerifySignature(roomVersion, homeserver, getKey)
+ if err != nil {
+ // 5.2. If content has a join_authorised_via_users_server key:
+ // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject.
+ return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err)
+ }
+ }
+ targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave"))
+ senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
+ switch membership {
+ case event.MembershipJoin:
+ createEvtID, err := createEvt.GetEventID(roomVersion)
+ if err != nil {
+ return fmt.Errorf("failed to get create event ID: %w", err)
+ }
+ creator := createEvt.Sender.String()
+ if roomVersion.CreatorInContent() {
+ creator = gjson.GetBytes(evt.Content, "creator").Str
+ }
+ if len(evt.PrevEvents) == 1 &&
+ len(evt.AuthEvents) <= 1 &&
+ evt.PrevEvents[0] == createEvtID &&
+ *evt.StateKey == creator {
+ // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow.
+ return nil
+ }
+ // Spec wart: this would make more sense before the check above.
+ // Now you can set anyone as the sender of the first join.
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.3.2. If the sender does not match state_key, reject.
+ return ErrCantJoinOtherUser
+ }
+
+ if senderMembership == event.MembershipBan {
+ // 5.3.3. If the sender is banned, reject.
+ return ErrCantJoinBanned
+ }
+
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ switch joinRule {
+ case event.JoinRuleKnock:
+ if !roomVersion.Knocks() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleInvite:
+ // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join.
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ return nil
+ }
+ return ErrCantJoinWithoutInvite
+ case event.JoinRuleKnockRestricted:
+ if !roomVersion.KnockRestricted() {
+ return ErrInvalidJoinRule
+ }
+ fallthrough
+ case event.JoinRuleRestricted:
+ if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() {
+ return ErrInvalidJoinRule
+ }
+ if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
+ // 5.3.5.1. If membership state is join or invite, allow.
+ return nil
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() {
+ // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject.
+ return ErrAuthoriserCantInvite
+ }
+ authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave)))
+ if authorizerMembership != event.MembershipJoin {
+ return ErrAuthoriserNotInRoom
+ }
+ // 5.3.5.3. Otherwise, allow.
+ return nil
+ case event.JoinRulePublic:
+ // 5.3.6. If the join_rule is public, allow.
+ return nil
+ default:
+ // 5.3.7. Otherwise, reject.
+ return ErrInvalidJoinRule
+ }
+ case event.MembershipInvite:
+ tpiVal := gjson.GetBytes(evt.Content, "third_party_invite")
+ if tpiVal.Exists() {
+ if targetPrevMembership == event.MembershipBan {
+ return ErrThirdPartyInviteBanned
+ }
+ signed := tpiVal.Get("signed")
+ mxid := signed.Get("mxid").Str
+ token := signed.Get("token").Str
+ if mxid == "" || token == "" {
+ // 5.4.1.2. If content.third_party_invite does not have a signed property, reject.
+ // 5.4.1.3. If signed does not have mxid and token properties, reject.
+ return ErrThirdPartyInviteMissingFields
+ }
+ if mxid != *evt.StateKey {
+ // 5.4.1.4. If mxid does not match state_key, reject.
+ return ErrThirdPartyInviteMXIDMismatch
+ }
+ tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token)
+ if tpiEvt == nil {
+ // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject.
+ return ErrThirdPartyInviteNotFound
+ }
+ if tpiEvt.Sender != evt.Sender {
+ // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject.
+ return ErrThirdPartyInviteSenderMismatch
+ }
+ var keys []id.Ed25519
+ const ed25519Base64Len = 43
+ oldPubKey := gjson.GetBytes(evt.Content, "public_key.token")
+ if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(oldPubKey.Str))
+ }
+ gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.Number {
+ return false
+ }
+ if value.Type == gjson.String && len(value.Str) == ed25519Base64Len {
+ keys = append(keys, id.Ed25519(value.Str))
+ }
+ return true
+ })
+ rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str))
+ var validated bool
+ for _, key := range keys {
+ if signutil.VerifyJSONAny(key, rawSigned) == nil {
+ validated = true
+ }
+ }
+ if validated {
+ // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow.
+ return nil
+ }
+ // 4.4.1.8. Otherwise, reject.
+ return ErrThirdPartyInviteNotSigned
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.4.2. If the sender’s current membership state is not join, reject.
+ return ErrInviterNotInRoom
+ }
+ // 5.4.3. If target user’s current membership state is join or ban, reject.
+ if targetPrevMembership == event.MembershipJoin {
+ return ErrInviteTargetAlreadyInRoom
+ } else if targetPrevMembership == event.MembershipBan {
+ return ErrInviteTargetBanned
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() {
+ // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow.
+ return nil
+ }
+ // 5.4.5. Otherwise, reject.
+ return ErrInsufficientPermissionForInvite
+ case event.MembershipLeave:
+ if evt.Sender.String() == *evt.StateKey {
+ // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock.
+ if senderMembership == event.MembershipInvite ||
+ senderMembership == event.MembershipJoin ||
+ (senderMembership == event.MembershipKnock && roomVersion.Knocks()) {
+ return nil
+ }
+ return ErrCantLeaveWithoutBeingInRoom
+ }
+ if senderMembership != event.MembershipJoin {
+ // 5.5.2. If the sender’s current membership state is not join, reject.
+ return ErrCantKickWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() {
+ // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject.
+ return ErrInsufficientPermissionForUnban
+ }
+ if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // TODO separate errors for < kick and < target user level?
+ // 5.5.5. Otherwise, reject.
+ return ErrInsufficientPermissionForKick
+ case event.MembershipBan:
+ if senderMembership != event.MembershipJoin {
+ // 5.6.1. If the sender’s current membership state is not join, reject.
+ return ErrCantBanWithoutBeingInRoom
+ }
+ powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
+ if err != nil {
+ return err
+ }
+ senderLevel := powerLevels.GetUserLevel(evt.Sender)
+ if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
+ // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow.
+ return nil
+ }
+ // 5.6.3. Otherwise, reject.
+ return ErrInsufficientPermissionForBan
+ case event.MembershipKnock:
+ joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
+ validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock
+ validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted
+ if !validKnockRule && !validKnockRestrictedRule {
+ // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject.
+ return ErrNotKnockableRoom
+ }
+ if evt.Sender.String() != *evt.StateKey {
+ // 5.7.2. If the sender does not match state_key, reject.
+ return ErrCantKnockOtherUser
+ }
+ if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin {
+ // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow.
+ return nil
+ }
+ // 5.7.4. Otherwise, reject.
+ return ErrCantKnockWhileInRoom
+ default:
+ // 5.8. Otherwise, the membership is unknown. Reject.
+ return ErrUnknownMembership
+ }
+}
+
+func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error {
+ if roomVersion.ValidatePowerLevelInts() {
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ res := gjson.GetBytes(evt.Content, key)
+ if !res.Exists() {
+ continue
+ }
+ if parseIntWithVersion(roomVersion, res) == nil {
+ // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject.
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ }
+ for _, key := range []string{"events", "notifications"} {
+ obj := gjson.GetBytes(evt.Content, key)
+ if !obj.Exists() {
+ continue
+ }
+ // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject.
+ if !obj.IsObject() {
+ return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
+ }
+ var err error
+ // 10.2. [...] are not an object with values that are integers, reject.
+ obj.ForEach(func(innerKey, value gjson.Result) bool {
+ if parseIntWithVersion(roomVersion, value) == nil {
+ err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ }
+ var creators []id.UserID
+ if roomVersion.PrivilegedRoomCreators() {
+ creators = append(creators, createEvt.Sender)
+ gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool {
+ creators = append(creators, id.UserID(value.Str))
+ return true
+ })
+ }
+ users := gjson.GetBytes(evt.Content, "users")
+ if users.Exists() {
+ if !users.IsObject() {
+ // 10.3. If the users property in content is not an object [...], reject.
+ return fmt.Errorf("%w users", ErrTopLevelPLNotInteger)
+ }
+ var err error
+ users.ForEach(func(key, value gjson.Result) bool {
+ if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil {
+ // 10.3. [...] is not an object with keys that are valid user IDs [...], reject.
+ err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr)
+ return false
+ }
+ if parseIntWithVersion(roomVersion, value) == nil {
+ // 10.3. [...] is not an object [...] with values that are integers, reject.
+ err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str)
+ return false
+ }
+ // creators is only filled if the room version has privileged room creators
+ if slices.Contains(creators, id.UserID(key.Str)) {
+ // 10.4. If the users property in content contains the sender of the m.room.create event or any of
+ // the additional_creators array (if present) from the content of the m.room.create event, reject.
+ err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str)
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ }
+ oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "")
+ if oldPL == nil {
+ // 10.5. If there is no previous m.room.power_levels event in the room, allow.
+ return nil
+ }
+ if slices.Contains(creators, evt.Sender) {
+ // Skip remaining checks for creators
+ return nil
+ }
+ senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String())))
+ if senderPLPtr == nil {
+ senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default"))
+ if senderPLPtr == nil {
+ senderPLPtr = ptr.Ptr(0)
+ }
+ }
+ for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
+ oldVal := gjson.GetBytes(oldPL.Content, key)
+ newVal := gjson.GetBytes(evt.Content, key)
+ if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil {
+ return err
+ }
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "events", "",
+ gjson.GetBytes(oldPL.Content, "events"),
+ gjson.GetBytes(evt.Content, "events"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "notifications", "",
+ gjson.GetBytes(oldPL.Content, "notifications"),
+ gjson.GetBytes(evt.Content, "notifications"),
+ ); err != nil {
+ return err
+ }
+ if err := allowPowerChangeMap(
+ roomVersion, *senderPLPtr, "users", evt.Sender.String(),
+ gjson.GetBytes(oldPL.Content, "users"),
+ gjson.GetBytes(evt.Content, "users"),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) {
+ old.ForEach(func(key, value gjson.Result) bool {
+ newVal := new.Get(exgjson.Path(key.Str))
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal)
+ if err == nil && ownID != "" && key.Str != ownID {
+ parsedOldVal := parseIntWithVersion(roomVersion, value)
+ parsedNewVal := parseIntWithVersion(roomVersion, newVal)
+ if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal {
+ err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal)
+ }
+ }
+ return err == nil
+ })
+ if err != nil {
+ return
+ }
+ new.ForEach(func(key, value gjson.Result) bool {
+ err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value)
+ return err == nil
+ })
+ return
+}
+
+func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error {
+ oldVal := parseIntWithVersion(roomVersion, old)
+ newVal := parseIntWithVersion(roomVersion, new)
+ if oldVal == nil {
+ if newVal == nil || *newVal <= maxVal {
+ return nil
+ }
+ } else if newVal == nil {
+ if *oldVal <= maxVal {
+ return nil
+ }
+ } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) {
+ return nil
+ }
+ return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal)
+}
+
+func stringifyForError(val gjson.Result) string {
+ if !val.Exists() {
+ return "null"
+ }
+ return val.Raw
+}
+
+func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU {
+ for _, evt := range events {
+ if evt.Type == evtType && *evt.StateKey == stateKey {
+ return evt
+ }
+ }
+ return nil
+}
+
+func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T {
+ return reader(findEvent(events, evtType, stateKey))
+}
+
+func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string {
+ return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string {
+ if evt == nil {
+ return defVal
+ }
+ res := gjson.GetBytes(evt.Content, fieldPath)
+ if res.Type != gjson.String {
+ return defVal
+ }
+ return res.Str
+ })
+}
+
+func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) {
+ var err error
+ powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent {
+ if evt == nil {
+ return nil
+ }
+ content := evt.Content
+ out := &event.PowerLevelsEventContent{}
+ if !roomVersion.ValidatePowerLevelInts() {
+ safeParsePowerLevels(content, out)
+ } else {
+ err = json.Unmarshal(content, out)
+ }
+ return out
+ })
+ if err != nil {
+ // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ if roomVersion.PrivilegedRoomCreators() {
+ if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{}
+ }
+ powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
+ }
+ } else if powerLevels == nil {
+ powerLevels = &event.PowerLevelsEventContent{
+ Users: map[id.UserID]int{
+ createEvt.Sender: 100,
+ },
+ }
+ }
+ return powerLevels, nil
+}
+
+func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int {
+ if roomVersion.ValidatePowerLevelInts() {
+ if val.Type != gjson.Number {
+ return nil
+ }
+ return ptr.Ptr(int(val.Int()))
+ }
+ return parsePythonInt(val)
+}
+
+func parsePythonInt(val gjson.Result) *int {
+ switch val.Type {
+ case gjson.True:
+ return ptr.Ptr(1)
+ case gjson.False:
+ return ptr.Ptr(0)
+ case gjson.Number:
+ return ptr.Ptr(int(val.Int()))
+ case gjson.String:
+ // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand
+ num, err := strconv.Atoi(strings.TrimSpace(val.Str))
+ if err != nil {
+ return nil
+ }
+ return &num
+ default:
+ // Python int() doesn't accept nulls, arrays or dicts
+ return nil
+ }
+}
+
+func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) {
+ *into = event.PowerLevelsEventContent{
+ Users: make(map[id.UserID]int),
+ UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))),
+ Events: make(map[string]int),
+ EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))),
+ Notifications: nil, // irrelevant for event auth
+ StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")),
+ InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")),
+ KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")),
+ BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")),
+ RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")),
+ }
+ gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val != nil {
+ into.Events[key.Str] = *val
+ }
+ return true
+ })
+ gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool {
+ if key.Type != gjson.String {
+ return false
+ }
+ val := parsePythonInt(value)
+ if val == nil {
+ return false
+ }
+ userID := id.UserID(key.Str)
+ if _, _, err := userID.Parse(); err != nil {
+ return false
+ }
+ into.Users[userID] = *val
+ return true
+ })
+}
diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go
new file mode 100644
index 00000000..d316f3c8
--- /dev/null
+++ b/federation/eventauth/eventauth_internal_test.go
@@ -0,0 +1,66 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+type pythonIntTest struct {
+ Name string
+ Input string
+ Expected int64
+}
+
+var pythonIntTests = []pythonIntTest{
+ {"True", `true`, 1},
+ {"False", `false`, 0},
+ {"SmallFloat", `3.1415`, 3},
+ {"SmallFloatRoundDown", `10.999999999999999`, 10},
+ {"SmallFloatRoundUp", `10.9999999999999999`, 11},
+ {"BigFloatRoundDown", `1000000.9999999999`, 1000000},
+ {"BigFloatRoundUp", `1000000.99999999999`, 1000001},
+ {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992},
+ {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994},
+ {"Int64", `9223372036854775807`, 9223372036854775807},
+ {"Int64String", `"9223372036854775807"`, 9223372036854775807},
+ {"String", `"123"`, 123},
+ {"InvalidFloatInString", `"123.456"`, 0},
+ {"StringWithPlusSign", `"+123"`, 123},
+ {"StringWithMinusSign", `"-123"`, -123},
+ {"StringWithSpaces", `" 123 "`, 123},
+ {"StringWithSpacesAndSign", `" -123 "`, -123},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0},
+ {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0},
+ {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0},
+ //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456},
+}
+
+func TestParsePythonInt(t *testing.T) {
+ for _, test := range pythonIntTests {
+ t.Run(test.Name, func(t *testing.T) {
+ output := parsePythonInt(gjson.Parse(test.Input))
+ if strings.HasPrefix(test.Name, "Invalid") {
+ assert.Nil(t, output)
+ } else {
+ require.NotNil(t, output)
+ assert.Equal(t, int(test.Expected), *output)
+ }
+ })
+ }
+}
diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go
new file mode 100644
index 00000000..e3c5cd76
--- /dev/null
+++ b/federation/eventauth/eventauth_test.go
@@ -0,0 +1,85 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth_test
+
+import (
+ "embed"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/federation/eventauth"
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+//go:embed *.jsonl
+var data embed.FS
+
+type eventMap map[id.EventID]*pdu.PDU
+
+func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) {
+ output := make([]*pdu.PDU, len(ids))
+ for i, evtID := range ids {
+ output[i] = em[evtID]
+ }
+ return output, nil
+}
+
+func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) {
+ return "", time.Time{}, nil
+}
+
+func TestAuthorize(t *testing.T) {
+ files := exerrors.Must(data.ReadDir("."))
+ for _, file := range files {
+ t.Run(file.Name(), func(t *testing.T) {
+ decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name())))
+ events := make(eventMap)
+ var roomVersion *id.RoomVersion
+ for i := 1; ; i++ {
+ var evt *pdu.PDU
+ err := json.UnmarshalDecode(decoder, &evt)
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ require.NoError(t, err)
+ if roomVersion == nil {
+ require.Equal(t, evt.Type, "m.room.create")
+ roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str))
+ }
+ expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str
+ evtID, err := evt.GetEventID(*roomVersion)
+ require.NoError(t, err)
+ require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i)
+
+ // TODO allow redacted events
+ assert.True(t, evt.VerifyContentHash(), i)
+
+ events[evtID] = evt
+ err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey)
+ if err != nil {
+ evt.InternalMeta.Rejected = true
+ }
+ // TODO allow testing intentionally rejected events
+ assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type)
+ }
+ })
+ }
+
+}
diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl
new file mode 100644
index 00000000..2b751de3
--- /dev/null
+++ b/federation/eventauth/testroom-v12-success.jsonl
@@ -0,0 +1,21 @@
+{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}}
+{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}}
+{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}}
+{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}}
+{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}}
+{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
+{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}}
+{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}}
+{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}}
+{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}}
+{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}}
diff --git a/federation/keyserver.go b/federation/keyserver.go
index b0faf8fb..d32ba5cf 100644
--- a/federation/keyserver.go
+++ b/federation/keyserver.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -12,9 +12,13 @@ import (
"strconv"
"time"
- "github.com/gorilla/mux"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
@@ -51,19 +55,25 @@ type KeyServer struct {
}
// Register registers the key server endpoints to the given router.
-func (ks *KeyServer) Register(r *mux.Router) {
- r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
- r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
- keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
- keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
- keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
- keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
- keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- mautrix.MUnrecognized.WithStatus(http.StatusNotFound).WithMessage("Unrecognized endpoint").Write(w)
- })
- keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- mautrix.MUnrecognized.WithStatus(http.StatusMethodNotAllowed).WithMessage("Invalid method for endpoint").Write(w)
- })
+func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) {
+ r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
+ r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
+ keyRouter := http.NewServeMux()
+ keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
+ keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
+ keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+ }
+ r.Handle("/_matrix/key/", exhttp.ApplyMiddleware(
+ keyRouter,
+ exhttp.StripPrefix("/_matrix/key"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
}
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
@@ -157,7 +167,7 @@ type GetQueryKeysResponse struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
- serverName := mux.Vars(r)["serverName"]
+ serverName := r.PathValue("serverName")
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
if err != nil && minimumValidUntilTSString != "" {
diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go
new file mode 100644
index 00000000..16706fe5
--- /dev/null
+++ b/federation/pdu/auth.go
@@ -0,0 +1,71 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "slices"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type StateKey struct {
+ Type string
+ StateKey string
+}
+
+var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token")
+
+type AuthEventSelection []StateKey
+
+func (aes *AuthEventSelection) Add(evtType, stateKey string) {
+ key := StateKey{Type: evtType, StateKey: stateKey}
+ if !aes.Has(key) {
+ *aes = append(*aes, key)
+ }
+}
+
+func (aes *AuthEventSelection) Has(key StateKey) bool {
+ return slices.Contains(*aes, key)
+}
+
+func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ if !roomVersion.RoomIDIsCreateEventID() {
+ keys.Add(event.StateCreate.Type, "")
+ }
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ if membership == event.MembershipJoin && roomVersion.RestrictedJoins() {
+ authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str
+ if authorizedVia != "" {
+ keys.Add(event.StateMember.Type, authorizedVia)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go
new file mode 100644
index 00000000..38ef83e9
--- /dev/null
+++ b/federation/pdu/hash.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *PDU) GetRoomID() (id.RoomID, error) {
+ if pdu == nil {
+ return "", ErrPDUIsNil
+ } else if pdu.Type != "m.room.create" {
+ return "", fmt.Errorf("room ID can only be calculated for m.room.create events")
+ } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() {
+ return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion)
+ } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil {
+ return "", fmt.Errorf("failed to calculate event ID: %w", err)
+ } else {
+ return id.RoomID(evtID), nil
+ }
+}
+
+var UseInternalMetaForGetEventID = false
+
+func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" {
+ return pdu.InternalMeta.EventID, nil
+ }
+ return pdu.calculateEventID(roomVersion, '$')
+}
+
+func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) {
+ referenceHash, err := pdu.GetReferenceHash(roomVersion)
+ if err != nil {
+ return "", err
+ }
+ eventID := make([]byte, 44)
+ eventID[0] = prefix
+ switch roomVersion.EventIDFormat() {
+ case id.EventIDFormatCustom:
+ return "", fmt.Errorf("*pdu.PDU can only be used for room v3+")
+ case id.EventIDFormatBase64:
+ base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:])
+ case id.EventIDFormatURLSafeBase64:
+ base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:])
+ default:
+ return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat())
+ }
+ return id.EventID(eventID), nil
+}
diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go
new file mode 100644
index 00000000..17417e12
--- /dev/null
+++ b/federation/pdu/hash_test.go
@@ -0,0 +1,55 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+)
+
+func TestPDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestPDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testPDUs {
+ if test.redacted {
+ continue
+ }
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestPDU_GetEventID(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion))
+ assert.Equal(t, test.eventID, gotEventID)
+ })
+ }
+}
diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go
new file mode 100644
index 00000000..17db6995
--- /dev/null
+++ b/federation/pdu/pdu.go
@@ -0,0 +1,156 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "bytes"
+ "crypto/ed25519"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/jsonbytes"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU.
+//
+// The input time is the timestamp of the event. The function should attempt to fetch a key that is
+// valid at or after this time, but if that is not possible, the latest available key should be
+// returned without an error. The verify function will do its own validity checking based on the
+// returned valid until timestamp.
+type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
+
+type AnyPDU interface {
+ GetRoomID() (id.RoomID, error)
+ GetEventID(roomVersion id.RoomVersion) (id.EventID, error)
+ GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error)
+ CalculateContentHash() ([32]byte, error)
+ FillContentHash() error
+ VerifyContentHash() bool
+ Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error
+ VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error
+ ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error)
+ AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection)
+}
+
+var (
+ _ AnyPDU = (*PDU)(nil)
+ _ AnyPDU = (*RoomV1PDU)(nil)
+)
+
+type InternalMeta struct {
+ EventID id.EventID `json:"event_id,omitempty"`
+ Rejected bool `json:"rejected,omitempty"`
+ Extra map[string]any `json:",unknown"`
+}
+
+type PDU struct {
+ AuthEvents []id.EventID `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []id.EventID `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+ InternalMeta InternalMeta `json:"-"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+var ErrPDUIsNil = errors.New("PDU is nil")
+
+type Hashes struct {
+ SHA256 jsonbytes.UnpaddedBytes `json:"sha256"`
+
+ Unknown jsontext.Value `json:",unknown"`
+}
+
+func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if pdu.Type == "m.room.create" && roomVersion == "" {
+ roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ eventID, err := pdu.GetEventID(roomVersion)
+ if err != nil {
+ return nil, err
+ }
+ roomID := pdu.RoomID
+ if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() {
+ roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1))
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: eventID,
+ RoomID: roomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err = json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) {
+ if signature == "" {
+ return
+ }
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = signature
+}
+
+func marshalCanonical(data any) (jsontext.Value, error) {
+ marshaledBytes, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+ marshaled := jsontext.Value(marshaledBytes)
+ err = marshaled.Canonicalize()
+ if err != nil {
+ return nil, err
+ }
+ check := canonicaljson.CanonicalJSONAssumeValid(marshaled)
+ if !bytes.Equal(marshaled, check) {
+ fmt.Println(string(marshaled))
+ fmt.Println(string(check))
+ return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled))
+ }
+ return marshaled, nil
+}
diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go
new file mode 100644
index 00000000..59d7c3a6
--- /dev/null
+++ b/federation/pdu/pdu_test.go
@@ -0,0 +1,193 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/json/v2"
+ "time"
+
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+type serverKey struct {
+ key id.SigningKey
+ validUntilTS time.Time
+}
+
+type serverDetails struct {
+ serverName string
+ keys map[id.KeyID]serverKey
+}
+
+func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ if serverName != sd.serverName {
+ return "", time.Time{}, nil
+ }
+ key, ok := sd.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+}
+
+var mauniumNet = serverDetails{
+ serverName: "maunium.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_xxeS": {
+ key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var envsNet = serverDetails{
+ serverName: "envs.net",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_zIqy": {
+ key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0",
+ validUntilTS: time.UnixMilli(1722360538068),
+ },
+ "ed25519:wuJyKT": {
+ key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var matrixOrg = serverDetails{
+ serverName: "matrix.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:auto": {
+ key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw",
+ validUntilTS: time.UnixMilli(1576767829750),
+ },
+ "ed25519:a_RXGa": {
+ key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var continuwuityOrg = serverDetails{
+ serverName: "continuwuity.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:PwHlNsFu": {
+ key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+var novaAstraltechOrg = serverDetails{
+ serverName: "nova.astraltech.org",
+ keys: map[id.KeyID]serverKey{
+ "ed25519:a_afpo": {
+ key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o",
+ validUntilTS: time.Now(),
+ },
+ },
+}
+
+type testPDU struct {
+ name string
+ pdu string
+ eventID id.EventID
+ roomVersion id.RoomVersion
+ redacted bool
+ serverDetails
+}
+
+var roomV4MessageTestPDU = testPDU{
+ name: "m.room.message in v4 room",
+ pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`,
+ eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}
+
+var roomV12MessageTestPDU = testPDU{
+ name: "m.room.message in v12 room",
+ pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`,
+ eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}
+
+var testPDUs = []testPDU{roomV4MessageTestPDU, {
+ name: "m.room.message in v5 room",
+ pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`,
+ eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA",
+ roomVersion: id.RoomV5,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v10 room",
+ pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`,
+ eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.message in v11 room",
+ pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`,
+ eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`,
+ roomVersion: id.RoomV11,
+ serverDetails: mauniumNet,
+}, roomV12MessageTestPDU, {
+ name: "m.room.create in v4 room",
+ pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`,
+ eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY",
+ roomVersion: id.RoomV4,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.create in v10 room",
+ pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`,
+ eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ",
+ roomVersion: id.RoomV10,
+ serverDetails: envsNet,
+}, {
+ name: "m.room.create in v12 room",
+ pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`,
+ eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v4 room",
+ pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`,
+ eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member in v10 room",
+ pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`,
+ eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs",
+ roomVersion: id.RoomV10,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.member of creator in v12 room",
+ pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`,
+ eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg",
+ roomVersion: id.RoomV12,
+ serverDetails: mauniumNet,
+}, {
+ name: "custom message event in v4 room",
+ pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`,
+ eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0",
+ roomVersion: id.RoomV4,
+ serverDetails: mauniumNet,
+}, {
+ name: "redacted m.room.member event in v11 room with 2 signatures",
+ pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`,
+ eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ",
+ roomVersion: id.RoomV11,
+ serverDetails: novaAstraltechOrg,
+ redacted: true,
+}}
+
+func parsePDU(pdu string) (out *pdu.PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go
new file mode 100644
index 00000000..d7ee0c15
--- /dev/null
+++ b/federation/pdu/redact.go
@@ -0,0 +1,111 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "encoding/json/jsontext"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/id"
+)
+
+func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value {
+ filtered := jsontext.Value("{}")
+ var err error
+ for _, path := range allowedPaths {
+ res := gjson.GetBytes(object, path)
+ if res.Exists() {
+ var raw jsontext.Value
+ if res.Index > 0 {
+ raw = object[res.Index : res.Index+len(res.Raw)]
+ } else {
+ raw = jsontext.Value(res.Raw)
+ }
+ filtered, err = sjson.SetRawBytes(filtered, path, raw)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }
+ return filtered
+}
+
+func (pdu *PDU) Clone() *PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+var emptyObject = jsontext.Value("{}")
+
+func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value {
+ switch eventType {
+ case "m.room.member":
+ allowedPaths := []string{"membership"}
+ if roomVersion.RestrictedJoinsFix() {
+ allowedPaths = append(allowedPaths, "join_authorised_via_users_server")
+ }
+ if roomVersion.UpdatedRedactionRules() {
+ allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed"))
+ }
+ return filteredObject(content, allowedPaths...)
+ case "m.room.create":
+ if !roomVersion.UpdatedRedactionRules() {
+ return filteredObject(content, "creator")
+ }
+ return content
+ case "m.room.join_rules":
+ if roomVersion.RestrictedJoins() {
+ return filteredObject(content, "join_rule", "allow")
+ }
+ return filteredObject(content, "join_rule")
+ case "m.room.power_levels":
+ allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"}
+ if roomVersion.UpdatedRedactionRules() {
+ allowedKeys = append(allowedKeys, "invite")
+ }
+ return filteredObject(content, allowedKeys...)
+ case "m.room.history_visibility":
+ return filteredObject(content, "history_visibility")
+ case "m.room.redaction":
+ if roomVersion.RedactsInContent() {
+ return filteredObject(content, "redacts")
+ }
+ return emptyObject
+ case "m.room.aliases":
+ if roomVersion.SpecialCasedAliasesAuth() {
+ return filteredObject(content, "aliases")
+ }
+ return emptyObject
+ default:
+ return emptyObject
+ }
+}
+
+func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if roomVersion.UpdatedRedactionRules() {
+ pdu.DeprecatedPrevState = nil
+ pdu.DeprecatedOrigin = nil
+ pdu.DeprecatedMembership = nil
+ }
+ if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go
new file mode 100644
index 00000000..04e7c5ef
--- /dev/null
+++ b/federation/pdu/signature.go
@@ -0,0 +1,60 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "fmt"
+ "time"
+
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature))
+ return nil
+}
+
+func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, validUntil, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() {
+ return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go
new file mode 100644
index 00000000..01df5076
--- /dev/null
+++ b/federation/pdu/signature_test.go
@@ -0,0 +1,102 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+func TestPDU_VerifySignature(t *testing.T) {
+ for _, test := range testPDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.NoError(t, err)
+ })
+ }
+}
+
+func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
+ test := roomV4MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.NoError(t, err)
+}
+
+func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ key = test.keys[keyID].key
+ validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ return
+ })
+ assert.Error(t, err)
+}
+
+func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) {
+ test := roomV12MessageTestPDU
+ parsed := parsePDU(test.pdu)
+ for _, sigs := range parsed.Signatures {
+ for key := range sigs {
+ sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC"
+ }
+ }
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
+ assert.Error(t, err)
+}
+
+func TestPDU_Sign(t *testing.T) {
+ pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil))
+ evt := &pdu.PDU{
+ AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"},
+ Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`),
+ Depth: 123,
+ OriginServerTS: 1755384351627,
+ PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"},
+ RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ Sender: "@tulir:example.com",
+ Type: "m.room.message",
+ }
+ err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey)
+ require.NoError(t, err)
+ err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
+ if serverName == "example.com" && keyID == "ed25519:rand" {
+ key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey))
+ validUntil = time.Now()
+ }
+ return
+ })
+ require.NoError(t, err)
+
+}
diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go
new file mode 100644
index 00000000..9557f8ab
--- /dev/null
+++ b/federation/pdu/v1.go
@@ -0,0 +1,277 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu
+
+import (
+ "crypto/ed25519"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json/jsontext"
+ "encoding/json/v2"
+ "fmt"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/federation/signutil"
+ "maunium.net/go/mautrix/id"
+)
+
+type V1EventReference struct {
+ ID id.EventID
+ Hashes Hashes
+}
+
+var (
+ _ json.UnmarshalerFrom = (*V1EventReference)(nil)
+ _ json.MarshalerTo = (*V1EventReference)(nil)
+)
+
+func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error {
+ return json.MarshalEncode(enc, []any{er.ID, er.Hashes})
+}
+
+func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
+ var ref V1EventReference
+ var data []jsontext.Value
+ if err := json.UnmarshalDecode(dec, &data); err != nil {
+ return err
+ } else if len(data) != 2 {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data))
+ } else if err = json.Unmarshal(data[0], &ref.ID); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err)
+ } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil {
+ return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err)
+ }
+ *er = ref
+ return nil
+}
+
+type RoomV1PDU struct {
+ AuthEvents []V1EventReference `json:"auth_events"`
+ Content jsontext.Value `json:"content"`
+ Depth int64 `json:"depth"`
+ EventID id.EventID `json:"event_id"`
+ Hashes *Hashes `json:"hashes,omitzero"`
+ OriginServerTS int64 `json:"origin_server_ts"`
+ PrevEvents []V1EventReference `json:"prev_events"`
+ Redacts *id.EventID `json:"redacts,omitzero"`
+ RoomID id.RoomID `json:"room_id"`
+ Sender id.UserID `json:"sender"`
+ Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
+ StateKey *string `json:"state_key,omitzero"`
+ Type string `json:"type"`
+ Unsigned jsontext.Value `json:"unsigned,omitzero"`
+
+ Unknown jsontext.Value `json:",unknown"`
+
+ // Deprecated legacy fields
+ DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
+ DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
+ DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
+}
+
+func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) {
+ return pdu.RoomID, nil
+}
+
+func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion)
+ }
+ return pdu.EventID, nil
+}
+
+func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Signatures = nil
+ return pdu.Redact(roomVersion)
+}
+
+func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU {
+ pdu.Unknown = nil
+ pdu.Unsigned = nil
+ if pdu.Type != "m.room.redaction" {
+ pdu.Redacts = nil
+ }
+ pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
+ return pdu
+}
+
+func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion)
+ }
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
+ if err := pdu.FillContentHash(); err != nil {
+ return [32]byte{}, err
+ }
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) {
+ if pdu == nil {
+ return [32]byte{}, ErrPDUIsNil
+ }
+ pduClone := pdu.Clone()
+ pduClone.Signatures = nil
+ pduClone.Unsigned = nil
+ pduClone.Hashes = nil
+ rawJSON, err := marshalCanonical(pduClone)
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
+ }
+ return sha256.Sum256(rawJSON), nil
+}
+
+func (pdu *RoomV1PDU) FillContentHash() error {
+ if pdu == nil {
+ return ErrPDUIsNil
+ } else if pdu.Hashes != nil {
+ return nil
+ } else if hash, err := pdu.CalculateContentHash(); err != nil {
+ return err
+ } else {
+ pdu.Hashes = &Hashes{SHA256: hash[:]}
+ return nil
+ }
+}
+
+func (pdu *RoomV1PDU) VerifyContentHash() bool {
+ if pdu == nil || pdu.Hashes == nil {
+ return false
+ }
+ calculatedHash, err := pdu.CalculateContentHash()
+ if err != nil {
+ return false
+ }
+ return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
+}
+
+func (pdu *RoomV1PDU) Clone() *RoomV1PDU {
+ return ptr.Clone(pdu)
+}
+
+func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion)
+ }
+ err := pdu.FillContentHash()
+ if err != nil {
+ return err
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
+ }
+ signature := ed25519.Sign(privateKey, rawJSON)
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature)
+ return nil
+}
+
+func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion)
+ }
+ rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
+ if err != nil {
+ return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
+ }
+ verified := false
+ for keyID, sig := range pdu.Signatures[serverName] {
+ originServerTS := time.UnixMilli(pdu.OriginServerTS)
+ key, _, err := getKey(serverName, keyID, originServerTS)
+ if err != nil {
+ return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
+ } else if key == "" {
+ return fmt.Errorf("key %s not found for %s", keyID, serverName)
+ } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
+ return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
+ } else {
+ verified = true
+ }
+ }
+ if !verified {
+ return fmt.Errorf("no verifiable signatures found for server %s", serverName)
+ }
+ return nil
+}
+
+func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool {
+ switch roomVersion {
+ case id.RoomV0, id.RoomV1, id.RoomV2:
+ return true
+ default:
+ return false
+ }
+}
+
+func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
+ if !pdu.SupportsRoomVersion(roomVersion) {
+ return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion)
+ }
+ evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
+ if pdu.StateKey != nil {
+ evtType.Class = event.StateEventType
+ }
+ evt := &event.Event{
+ StateKey: pdu.StateKey,
+ Sender: pdu.Sender,
+ Type: evtType,
+ Timestamp: pdu.OriginServerTS,
+ ID: pdu.EventID,
+ RoomID: pdu.RoomID,
+ Redacts: ptr.Val(pdu.Redacts),
+ }
+ err := json.Unmarshal(pdu.Content, &evt.Content)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ }
+ return evt, nil
+}
+
+func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) {
+ if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
+ return AuthEventSelection{}
+ }
+ keys = make(AuthEventSelection, 0, 3)
+ keys.Add(event.StateCreate.Type, "")
+ keys.Add(event.StatePowerLevels.Type, "")
+ keys.Add(event.StateMember.Type, pdu.Sender.String())
+ if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
+ keys.Add(event.StateMember.Type, *pdu.StateKey)
+ membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
+ if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
+ keys.Add(event.StateJoinRules.Type, "")
+ }
+ if membership == event.MembershipInvite {
+ thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
+ if thirdPartyInviteToken != "" {
+ keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
+ }
+ }
+ }
+ return
+}
diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go
new file mode 100644
index 00000000..ecf2dbd2
--- /dev/null
+++ b/federation/pdu/v1_test.go
@@ -0,0 +1,86 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package pdu_test
+
+import (
+ "encoding/base64"
+ "encoding/json/v2"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/federation/pdu"
+ "maunium.net/go/mautrix/id"
+)
+
+var testV1PDUs = []testPDU{{
+ name: "m.room.message in v1 room",
+ pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`,
+ eventID: "$16738169022163bokdi:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}, {
+ name: "m.room.create in v1 room",
+ pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`,
+ eventID: "$143454825711DhCxH:matrix.org",
+ roomVersion: id.RoomV1,
+ serverDetails: matrixOrg,
+}, {
+ name: "m.room.member in v1 room",
+ pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`,
+ eventID: "$15660585693272iEryv:maunium.net",
+ roomVersion: id.RoomV1,
+ serverDetails: mauniumNet,
+}}
+
+func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) {
+ exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
+ return
+}
+
+func TestRoomV1PDU_CalculateContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ contentHash := exerrors.Must(parsed.CalculateContentHash())
+ assert.Equal(
+ t,
+ base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
+ base64.RawStdEncoding.EncodeToString(contentHash[:]),
+ )
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifyContentHash(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ assert.True(t, parsed.VerifyContentHash())
+ })
+ }
+}
+
+func TestRoomV1PDU_VerifySignature(t *testing.T) {
+ for _, test := range testV1PDUs {
+ t.Run(test.name, func(t *testing.T) {
+ parsed := parseV1PDU(test.pdu)
+ err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
+ key, ok := test.keys[keyID]
+ if ok {
+ return key.key, key.validUntilTS, nil
+ }
+ return "", time.Time{}, nil
+ })
+ assert.NoError(t, err)
+ })
+ }
+}
diff --git a/federation/resolution.go b/federation/resolution.go
index 69d4d3bf..a3188266 100644
--- a/federation/resolution.go
+++ b/federation/resolution.go
@@ -20,6 +20,8 @@ import (
"time"
"github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
)
type ResolvedServerName struct {
@@ -78,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
- hostname, port, ok = ParseServerName(wellKnown.Server)
+ wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
+ if ok {
+ hostname, port = wkHost, wkPort
+ }
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
@@ -171,9 +176,11 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+ } else if resp.ContentLength > mautrix.WellKnownMaxSize {
+ return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
}
var respData RespWellKnown
- err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
+ err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {
diff --git a/federation/serverauth.go b/federation/serverauth.go
index f46c7991..cd300341 100644
--- a/federation/serverauth.go
+++ b/federation/serverauth.go
@@ -231,7 +231,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res
}
err = (&signableRequest{
Method: r.Method,
- URI: r.URL.EscapedPath(),
+ URI: r.URL.RequestURI(),
Origin: parsed.Origin,
Destination: destination,
Content: reqBody,
diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go
index 9fa15459..f99fc6cf 100644
--- a/federation/serverauth_test.go
+++ b/federation/serverauth_test.go
@@ -19,9 +19,9 @@ import (
func TestServerKeyResponse_VerifySelfSignature(t *testing.T) {
cli := federation.NewClient("", nil, nil)
ctx := context.Background()
- for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} {
+ for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} {
t.Run(name, func(t *testing.T) {
- resp, err := cli.ServerKeys(ctx, "matrix.org")
+ resp, err := cli.ServerKeys(ctx, name)
require.NoError(t, err)
assert.NoError(t, resp.VerifySelfSignature())
})
diff --git a/federation/signingkey.go b/federation/signingkey.go
index 0ae6a571..a4ad9679 100644
--- a/federation/signingkey.go
+++ b/federation/signingkey.go
@@ -10,17 +10,15 @@ import (
"crypto/ed25519"
"encoding/base64"
"encoding/json"
- "errors"
"fmt"
"strings"
"time"
- "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
- "go.mau.fi/util/exgjson"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -35,8 +33,8 @@ type SigningKey struct {
//
// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function.
func (sk *SigningKey) SynapseString() string {
- alg, id := sk.ID.Parse()
- return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
+ alg, keyID := sk.ID.Parse()
+ return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
}
// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey.
@@ -100,56 +98,13 @@ func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool {
func (skr *ServerKeyResponse) VerifySelfSignature() error {
for keyID, key := range skr.VerifyKeys {
- if err := VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil {
+ if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil {
return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err)
}
}
return nil
}
-func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error {
- var err error
- message, ok := data.(json.RawMessage)
- if !ok {
- message, err = json.Marshal(data)
- if err != nil {
- return fmt.Errorf("failed to marshal data: %w", err)
- }
- }
- sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID)))
- if sigVal.Type != gjson.String {
- return ErrSignatureNotFound
- }
- message, err = sjson.DeleteBytes(message, "signatures")
- if err != nil {
- return fmt.Errorf("failed to delete signatures: %w", err)
- }
- message, err = sjson.DeleteBytes(message, "unsigned")
- if err != nil {
- return fmt.Errorf("failed to delete unsigned: %w", err)
- }
- return VerifyJSONRaw(key, sigVal.Str, message)
-}
-
-var ErrSignatureNotFound = errors.New("signature not found")
-var ErrInvalidSignature = errors.New("invalid signature")
-
-func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error {
- sigBytes, err := base64.RawStdEncoding.DecodeString(sig)
- if err != nil {
- return fmt.Errorf("failed to decode signature: %w", err)
- }
- keyBytes, err := base64.RawStdEncoding.DecodeString(string(key))
- if err != nil {
- return fmt.Errorf("failed to decode key: %w", err)
- }
- message = canonicaljson.CanonicalJSONAssumeValid(message)
- if !ed25519.Verify(keyBytes, message, sigBytes) {
- return ErrInvalidSignature
- }
- return nil
-}
-
type marshalableSKR ServerKeyResponse
func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error {
diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go
new file mode 100644
index 00000000..ea0e7886
--- /dev/null
+++ b/federation/signutil/verify.go
@@ -0,0 +1,106 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package signutil
+
+import (
+ "crypto/ed25519"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
+
+ "maunium.net/go/mautrix/crypto/canonicaljson"
+ "maunium.net/go/mautrix/id"
+)
+
+var ErrSignatureNotFound = errors.New("signature not found")
+var ErrInvalidSignature = errors.New("invalid signature")
+
+func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID)))
+ if sigVal.Type != gjson.String {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ return VerifyJSONRaw(key, sigVal.Str, message)
+}
+
+func VerifyJSONAny(key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigs := gjson.GetBytes(message, "signatures")
+ if !sigs.IsObject() {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ var validated bool
+ sigs.ForEach(func(_, value gjson.Result) bool {
+ if !value.IsObject() {
+ return true
+ }
+ value.ForEach(func(_, value gjson.Result) bool {
+ if value.Type != gjson.String {
+ return true
+ }
+ validated = VerifyJSONRaw(key, value.Str, message) == nil
+ return !validated
+ })
+ return !validated
+ })
+ if !validated {
+ return ErrInvalidSignature
+ }
+ return nil
+}
+
+func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error {
+ sigBytes, err := base64.RawStdEncoding.DecodeString(sig)
+ if err != nil {
+ return fmt.Errorf("failed to decode signature: %w", err)
+ }
+ keyBytes, err := base64.RawStdEncoding.DecodeString(string(key))
+ if err != nil {
+ return fmt.Errorf("failed to decode key: %w", err)
+ }
+ message = canonicaljson.CanonicalJSONAssumeValid(message)
+ if !ed25519.Verify(keyBytes, message, sigBytes) {
+ return ErrInvalidSignature
+ }
+ return nil
+}
diff --git a/filter.go b/filter.go
index c6c8211b..54973dab 100644
--- a/filter.go
+++ b/filter.go
@@ -57,7 +57,7 @@ type FilterPart struct {
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
- return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
+ return errors.New("bad event_format value")
}
return nil
}
diff --git a/format/htmlparser.go b/format/htmlparser.go
index f9d51e39..e0507d93 100644
--- a/format/htmlparser.go
+++ b/format/htmlparser.go
@@ -13,6 +13,7 @@ import (
"strconv"
"strings"
+ "go.mau.fi/util/exstrings"
"golang.org/x/net/html"
"maunium.net/go/mautrix/event"
@@ -92,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string
}
}
+func onlyBacktickCount(line string) (count int) {
+ for i := 0; i < len(line); i++ {
+ if line[i] != '`' {
+ return -1
+ }
+ count++
+ }
+ return
+}
+
+func DefaultMonospaceBlockConverter(code, language string, ctx Context) string {
+ if len(code) == 0 || code[len(code)-1] != '\n' {
+ code += "\n"
+ }
+ fence := "```"
+ for line := range strings.SplitSeq(code, "\n") {
+ count := onlyBacktickCount(strings.TrimSpace(line))
+ if count >= len(fence) {
+ fence = strings.Repeat("`", count+1)
+ }
+ }
+ return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence)
+}
+
// HTMLParser is a somewhat customizable Matrix HTML parser.
type HTMLParser struct {
PillConverter PillConverter
@@ -286,7 +311,10 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string {
}
if parser.LinkConverter != nil {
return parser.LinkConverter(str, href, ctx)
- } else if str == href {
+ } else if str == href ||
+ str == strings.TrimPrefix(href, "mailto:") ||
+ str == strings.TrimPrefix(href, "http://") ||
+ str == strings.TrimPrefix(href, "https://") {
return str
}
return fmt.Sprintf("%s (%s)", str, href)
@@ -344,10 +372,7 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
if parser.MonospaceBlockConverter != nil {
return parser.MonospaceBlockConverter(preStr, language, ctx)
}
- if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' {
- preStr += "\n"
- }
- return fmt.Sprintf("```%s\n%s```", language, preStr)
+ return DefaultMonospaceBlockConverter(preStr, language, ctx)
default:
return parser.nodeToTagAwareString(node.FirstChild, ctx)
}
@@ -368,7 +393,7 @@ func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) Tagge
switch node.Type {
case html.TextNode:
if !ctx.PreserveWhitespace {
- node.Data = strings.Replace(node.Data, "\n", "", -1)
+ node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", ""))
}
if parser.TextConverter != nil {
node.Data = parser.TextConverter(node.Data, ctx)
diff --git a/format/markdown.go b/format/markdown.go
index 3d9979b4..77ced0dc 100644
--- a/format/markdown.go
+++ b/format/markdown.go
@@ -57,7 +57,18 @@ type uriAble interface {
}
func MarkdownMention(id uriAble) string {
- return MarkdownLink(id.String(), id.URI().MatrixToURL())
+ return MarkdownMentionWithName(id.String(), id)
+}
+
+func MarkdownMentionWithName(name string, id uriAble) string {
+ return MarkdownLink(name, id.URI().MatrixToURL())
+}
+
+func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string {
+ if name == "" {
+ name = id.String()
+ }
+ return MarkdownLink(name, id.URI(via...).MatrixToURL())
}
func MarkdownLink(name string, url string) string {
diff --git a/go.mod b/go.mod
index 59f29c0c..49a1d4e4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,43 +1,42 @@
module maunium.net/go/mautrix
-go 1.23.0
+go 1.25.0
-toolchain go1.24.4
+toolchain go1.26.0
require (
- filippo.io/edwards25519 v1.1.0
+ filippo.io/edwards25519 v1.2.0
github.com/chzyer/readline v1.5.1
- github.com/gorilla/mux v1.8.0
- github.com/gorilla/websocket v1.5.0
- github.com/lib/pq v1.10.9
- github.com/mattn/go-sqlite3 v1.14.28
+ github.com/coder/websocket v1.8.14
+ github.com/lib/pq v1.11.2
+ github.com/mattn/go-sqlite3 v1.14.34
github.com/rs/xid v1.6.0
github.com/rs/zerolog v1.34.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
- github.com/stretchr/testify v1.10.0
+ github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
- github.com/yuin/goldmark v1.7.12
- go.mau.fi/util v0.8.8
- go.mau.fi/zeroconfig v0.1.3
- golang.org/x/crypto v0.40.0
- golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc
- golang.org/x/net v0.42.0
- golang.org/x/sync v0.16.0
+ github.com/yuin/goldmark v1.7.16
+ go.mau.fi/util v0.9.6
+ go.mau.fi/zeroconfig v0.2.0
+ golang.org/x/crypto v0.48.0
+ golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
+ golang.org/x/net v0.50.0
+ golang.org/x/sync v0.19.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
)
require (
- github.com/coreos/go-systemd/v22 v22.5.0 // indirect
+ github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
- github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb // indirect
+ github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
- golang.org/x/sys v0.34.0 // indirect
- golang.org/x/text v0.27.0 // indirect
+ golang.org/x/sys v0.41.0 // indirect
+ golang.org/x/text v0.34.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)
diff --git a/go.sum b/go.sum
index 9f48386e..871a5156 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
-filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
+filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
@@ -8,17 +8,16 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
-github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
+github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
+github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo=
+github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
-github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
-github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
-github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
+github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@@ -26,10 +25,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
-github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb h1:3PrKuO92dUTMrQ9dx0YNejC6U/Si6jqKmyQ9vWjwqR4=
-github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
+github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
+github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -39,8 +38,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -51,28 +50,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY=
-github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
-go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c=
-go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c=
-go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
-go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
-golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
-golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
-golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
-golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
-golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
-golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
-golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
-golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE=
+github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
+go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts=
+go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI=
+go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU=
+go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w=
+golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
+golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
+golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
+golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
-golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
-golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
-golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
+golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
+golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
+golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
diff --git a/id/contenturi.go b/id/contenturi.go
index e6a313f5..67127b6c 100644
--- a/id/contenturi.go
+++ b/id/contenturi.go
@@ -17,8 +17,14 @@ import (
)
var (
- InvalidContentURI = errors.New("invalid Matrix content URI")
- InputNotJSONString = errors.New("input doesn't look like a JSON string")
+ ErrInvalidContentURI = errors.New("invalid Matrix content URI")
+ ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ InvalidContentURI = ErrInvalidContentURI
+ InputNotJSONString = ErrInputNotJSONString
)
// ContentURIString is a string that's expected to be a Matrix content URI.
@@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !strings.HasPrefix(uri, "mxc://") {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = uri[6 : 6+index]
parsed.FileID = uri[6+index+1:]
@@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !bytes.HasPrefix(uri, mxcBytes) {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = string(uri[6 : 6+index])
parsed.FileID = string(uri[6+index+1:])
@@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
*uri = ContentURI{}
return nil
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
- return InputNotJSONString
+ return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString)
}
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
if err != nil {
diff --git a/id/crypto.go b/id/crypto.go
index 355a84a8..ee857f78 100644
--- a/id/crypto.go
+++ b/id/crypto.go
@@ -53,6 +53,34 @@ const (
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
)
+type KeySource string
+
+func (source KeySource) String() string {
+ return string(source)
+}
+
+func (source KeySource) Int() int {
+ switch source {
+ case KeySourceDirect:
+ return 100
+ case KeySourceBackup:
+ return 90
+ case KeySourceImport:
+ return 80
+ case KeySourceForward:
+ return 50
+ default:
+ return 0
+ }
+}
+
+const (
+ KeySourceDirect KeySource = "direct"
+ KeySourceBackup KeySource = "backup"
+ KeySourceImport KeySource = "import"
+ KeySourceForward KeySource = "forward"
+)
+
// BackupVersion is an arbitrary string that identifies a server side key backup.
type KeyBackupVersion string
diff --git a/id/matrixuri.go b/id/matrixuri.go
index 8f5ec849..d5c78bc7 100644
--- a/id/matrixuri.go
+++ b/id/matrixuri.go
@@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
func (uri *MatrixURI) getQuery() url.Values {
q := make(url.Values)
- if uri.Via != nil && len(uri.Via) > 0 {
+ if len(uri.Via) > 0 {
q["via"] = uri.Via
}
if len(uri.Action) > 0 {
diff --git a/id/opaque.go b/id/opaque.go
index 1d9f0dcf..c1ad4988 100644
--- a/id/opaque.go
+++ b/id/opaque.go
@@ -32,6 +32,9 @@ type EventID string
// https://github.com/matrix-org/matrix-doc/pull/2716
type BatchID string
+// A DelayID is a string identifying a delayed event.
+type DelayID string
+
func (roomID RoomID) String() string {
return string(roomID)
}
diff --git a/id/roomversion.go b/id/roomversion.go
new file mode 100644
index 00000000..578c10bd
--- /dev/null
+++ b/id/roomversion.go
@@ -0,0 +1,265 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package id
+
+import (
+ "errors"
+ "fmt"
+ "slices"
+)
+
+type RoomVersion string
+
+const (
+ RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1
+ RoomV1 RoomVersion = "1"
+ RoomV2 RoomVersion = "2"
+ RoomV3 RoomVersion = "3"
+ RoomV4 RoomVersion = "4"
+ RoomV5 RoomVersion = "5"
+ RoomV6 RoomVersion = "6"
+ RoomV7 RoomVersion = "7"
+ RoomV8 RoomVersion = "8"
+ RoomV9 RoomVersion = "9"
+ RoomV10 RoomVersion = "10"
+ RoomV11 RoomVersion = "11"
+ RoomV12 RoomVersion = "12"
+)
+
+func (rv RoomVersion) Equals(versions ...RoomVersion) bool {
+ return slices.Contains(versions, rv)
+}
+
+func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool {
+ return !rv.Equals(versions...)
+}
+
+var ErrUnknownRoomVersion = errors.New("unknown room version")
+
+func (rv RoomVersion) unknownVersionError() error {
+ return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv)
+}
+
+func (rv RoomVersion) IsKnown() bool {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12:
+ return true
+ default:
+ return false
+ }
+}
+
+type StateResVersion int
+
+const (
+ // StateResV1 is the original state resolution algorithm.
+ StateResV1 StateResVersion = 0
+ // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759
+ StateResV2 StateResVersion = 1
+ // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297
+ StateResV2_1 StateResVersion = 2
+)
+
+// StateResVersion returns the version of the state resolution algorithm used by this room version.
+func (rv RoomVersion) StateResVersion() StateResVersion {
+ switch rv {
+ case RoomV0, RoomV1:
+ return StateResV1
+ case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11:
+ return StateResV2
+ case RoomV12:
+ return StateResV2_1
+ default:
+ panic(rv.unknownVersionError())
+ }
+}
+
+type EventIDFormat int
+
+const (
+ // EventIDFormatCustom is the original format used by room v1 and v2.
+ // Event IDs in this format are an arbitrary string followed by a colon and the server name.
+ EventIDFormatCustom EventIDFormat = 0
+ // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659.
+ // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatBase64 EventIDFormat = 1
+ // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002.
+ // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event.
+ EventIDFormatURLSafeBase64 EventIDFormat = 2
+)
+
+// EventIDFormat returns the format of event IDs used by this room version.
+func (rv RoomVersion) EventIDFormat() EventIDFormat {
+ switch rv {
+ case RoomV0, RoomV1, RoomV2:
+ return EventIDFormatCustom
+ case RoomV3:
+ return EventIDFormatBase64
+ default:
+ return EventIDFormatURLSafeBase64
+ }
+}
+
+/////////////////////
+// Room v5 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2077
+
+// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys
+// must be enforced on received events.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076
+func (rv RoomVersion) EnforceSigningKeyValidity() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4)
+}
+
+/////////////////////
+// Room v6 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2240
+
+// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased
+// to only always allow servers to modify the state event with their own server name as state key.
+// This also implies that the `aliases` field is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432
+func (rv RoomVersion) SpecialCasedAliasesAuth() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540
+func (rv RoomVersion) ForbidFloatsAndBigInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth.
+// However, the field is not protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209
+func (rv RoomVersion) NotificationsPowerLevels() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
+}
+
+/////////////////////
+// Room v7 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/2998
+
+// Knocks returns true if the `knock` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403
+func (rv RoomVersion) Knocks() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6)
+}
+
+/////////////////////
+// Room v8 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3289
+
+// RestrictedJoins returns true if the `restricted` join rule is supported.
+// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083
+func (rv RoomVersion) RestrictedJoins() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7)
+}
+
+/////////////////////
+// Room v9 changes //
+/////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+
+// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375
+func (rv RoomVersion) RestrictedJoinsFix() bool {
+ return rv.RestrictedJoins() && rv != RoomV8
+}
+
+//////////////////////
+// Room v10 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3604
+
+// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings).
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667
+func (rv RoomVersion) ValidatePowerLevelInts() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+// KnockRestricted returns true if the `knock_restricted` join rule is supported.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787
+func (rv RoomVersion) KnockRestricted() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
+}
+
+//////////////////////
+// Room v11 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3820
+
+// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175
+func (rv RoomVersion) CreatorInContent() bool {
+ return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level.
+// The redaction protection is also moved from the top level to the content field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174
+// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection).
+func (rv RoomVersion) RedactsInContent() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied.
+//
+// Specifically:
+//
+// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected.
+// * the entire content of `m.room.create` is protected.
+// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field.
+// * the `m.room.power_levels` event protects the `invite` field in content.
+// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176,
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and
+// https://github.com/matrix-org/matrix-spec-proposals/pull/3989
+func (rv RoomVersion) UpdatedRedactionRules() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
+}
+
+//////////////////////
+// Room v12 changes //
+//////////////////////
+// https://github.com/matrix-org/matrix-spec-proposals/pull/4304
+
+// Return value of StateResVersion was changed to StateResV2_1
+
+// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level.
+// This also implies that the `m.room.create` event has an `additional_creators` field,
+// and that the creators can't be present in the `m.room.power_levels` event.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289
+func (rv RoomVersion) PrivilegedRoomCreators() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
+
+// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event.
+// This also implies that `m.room.create` events do not have a `room_id` field.
+//
+// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291
+func (rv RoomVersion) RoomIDIsCreateEventID() bool {
+ return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
+}
diff --git a/id/trust.go b/id/trust.go
index 04f6e36b..6255093e 100644
--- a/id/trust.go
+++ b/id/trust.go
@@ -16,6 +16,7 @@ type TrustState int
const (
TrustStateBlacklisted TrustState = -100
+ TrustStateDeviceKeyMismatch TrustState = -5
TrustStateUnset TrustState = 0
TrustStateUnknownDevice TrustState = 10
TrustStateForwarded TrustState = 20
@@ -23,7 +24,7 @@ const (
TrustStateCrossSignedTOFU TrustState = 100
TrustStateCrossSignedVerified TrustState = 200
TrustStateVerified TrustState = 300
- TrustStateInvalid TrustState = (1 << 31) - 1
+ TrustStateInvalid TrustState = -2147483647
)
func (ts *TrustState) UnmarshalText(data []byte) error {
@@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState {
switch strings.ToLower(val) {
case "blacklisted":
return TrustStateBlacklisted
+ case "device-key-mismatch":
+ return TrustStateDeviceKeyMismatch
case "unverified":
return TrustStateUnset
case "cross-signed-untrusted":
@@ -67,6 +70,8 @@ func (ts TrustState) String() string {
switch ts {
case TrustStateBlacklisted:
return "blacklisted"
+ case TrustStateDeviceKeyMismatch:
+ return "device-key-mismatch"
case TrustStateUnset:
return "unverified"
case TrustStateCrossSignedUntrusted:
diff --git a/id/userid.go b/id/userid.go
index 6d9f4080..726a0d58 100644
--- a/id/userid.go
+++ b/id/userid.go
@@ -104,16 +104,24 @@ func ValidateUserLocalpart(localpart string) error {
return nil
}
-// ParseAndValidate parses the user ID into the localpart and server name like Parse,
-// and also validates that the localpart is allowed according to the user identifiers spec.
-func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.Parse()
+// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts.
+// This should be used with care: there are real users still using historical localparts.
+func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) {
+ localpart, homeserver, err = userID.ParseAndValidateRelaxed()
if err == nil {
err = ValidateUserLocalpart(localpart)
}
- if err == nil && len(userID) > UserIDMaxLength {
+ return
+}
+
+// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse,
+// and also validates that the user ID is not too long and that the server name is valid.
+func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) {
+ if len(userID) > UserIDMaxLength {
err = ErrUserIDTooLong
+ return
}
+ localpart, homeserver, err = userID.Parse()
if err == nil && !ValidateServerName(homeserver) {
err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart)
}
@@ -121,7 +129,7 @@ func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.ParseAndValidate()
+ localpart, homeserver, err = userID.ParseAndValidateStrict()
if err == nil {
localpart, err = DecodeUserLocalpart(localpart)
}
@@ -211,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) {
for i := 0; i < len(strBytes); i++ {
b := strBytes[i]
if !isValidByte(b) {
- return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
+ return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
}
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
if i+1 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
}
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
+ return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
}
if strBytes[i+1] == '_' {
outputBuffer.WriteByte('_')
@@ -229,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) {
i++ // skip next byte since we just handled it
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
if i+2 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
}
dst := make([]byte, 1)
_, err := hex.Decode(dst, strBytes[i+1:i+3])
diff --git a/id/userid_test.go b/id/userid_test.go
index 359bc687..57a88066 100644
--- a/id/userid_test.go
+++ b/id/userid_test.go
@@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) {
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
}
-func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
}
-func TestUserID_ParseAndValidate_Empty(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) {
const inputUserID = "@:ponies.im"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
}
-func TestUserID_ParseAndValidate_Long(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_Long(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
}
-func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
+func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidate()
+ _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
assert.NoError(t, err)
}
@@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) {
const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
const inputServerName = "example.com"
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
- parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
+ parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict()
assert.NoError(t, err)
assert.Equal(t, encodedLocalpart, parsedLocalpart)
assert.Equal(t, inputServerName, parsedServerName)
diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go
index 4be799d3..4d2bc7cf 100644
--- a/mediaproxy/mediaproxy.go
+++ b/mediaproxy/mediaproxy.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -21,8 +21,12 @@ import (
"strings"
"time"
- "github.com/gorilla/mux"
"github.com/rs/zerolog"
+ "github.com/rs/zerolog/hlog"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/ptr"
+ "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/federation"
@@ -91,9 +95,13 @@ func (d *GetMediaResponseCallback) GetContentType() string {
return d.ContentType
}
+type FileMeta struct {
+ ContentType string
+ ReplacementFile string
+}
+
type GetMediaResponseFile struct {
- Callback func(w *os.File) error
- ContentType string
+ Callback func(w *os.File) (*FileMeta, error)
}
type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
@@ -108,8 +116,8 @@ type MediaProxy struct {
serverName string
serverKey *federation.SigningKey
- FederationRouter *mux.Router
- ClientMediaRouter *mux.Router
+ FederationRouter *http.ServeMux
+ ClientMediaRouter *http.ServeMux
}
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
@@ -117,7 +125,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
if err != nil {
return nil, err
}
- return &MediaProxy{
+ mp := &MediaProxy{
serverName: serverName,
serverKey: parsed,
GetMedia: getMedia,
@@ -132,7 +140,21 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
},
},
- }, nil
+ }
+ mp.FederationRouter = http.NewServeMux()
+ mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
+ mp.ClientMediaRouter = http.NewServeMux()
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
+ mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported)
+ mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported)
+ return mp, nil
}
type BasicConfig struct {
@@ -162,8 +184,8 @@ type ServerConfig struct {
}
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
- router := mux.NewRouter()
- mp.RegisterRoutes(router)
+ router := http.NewServeMux()
+ mp.RegisterRoutes(router, zerolog.Nop())
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
}
@@ -188,39 +210,29 @@ func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache feder
})
}
-func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
- if mp.FederationRouter == nil {
- mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
+func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) {
+ errorBodies := exhttp.ErrorBodies{
+ NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
+ MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
}
- if mp.ClientMediaRouter == nil {
- mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
- }
-
- mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
- mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
- mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost)
- mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost)
- mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet)
- mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
- mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
- mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
- mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
- mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
- corsMiddleware := func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
- w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
- w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
- next.ServeHTTP(w, r)
- })
- }
- mp.ClientMediaRouter.Use(corsMiddleware)
- mp.KeyServer.Register(router)
+ router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware(
+ mp.FederationRouter,
+ exhttp.StripPrefix("/_matrix/federation"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware(
+ mp.ClientMediaRouter,
+ exhttp.StripPrefix("/_matrix/client/v1/media"),
+ hlog.NewHandler(log),
+ hlog.RequestIDHandler("request_id", "Request-Id"),
+ exhttp.CORSMiddleware,
+ requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
+ exhttp.HandleErrors(errorBodies),
+ ))
+ mp.KeyServer.Register(router, log)
}
var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
@@ -234,7 +246,7 @@ func queryToMap(vals url.Values) map[string]string {
}
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
- mediaID := mux.Vars(r)["mediaID"]
+ mediaID := r.PathValue("mediaID")
if !id.IsValidMediaID(mediaID) {
mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w)
return nil
@@ -380,8 +392,7 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := zerolog.Ctx(ctx)
- vars := mux.Vars(r)
- if vars["serverName"] != mp.serverName {
+ if r.PathValue("serverName") != mp.serverName {
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
return
}
@@ -404,7 +415,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTemporaryRedirect)
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
- mp.addHeaders(w, mimeType, vars["fileName"])
+ mp.addHeaders(w, mimeType, r.PathValue("fileName"))
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.WriteHeader(http.StatusOK)
_, err := wt.WriteTo(w)
@@ -425,7 +436,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
if dataResp, ok := writerResp.(*GetMediaResponseData); ok {
defer dataResp.Reader.Close()
}
- mp.addHeaders(w, writerResp.GetContentType(), vars["fileName"])
+ mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName"))
if writerResp.GetContentLength() != 0 {
w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10))
}
@@ -447,23 +458,35 @@ func doTempFileDownload(
if err != nil {
return false, fmt.Errorf("failed to create temp file: %w", err)
}
+ origTempFile := tempFile
defer func() {
- _ = tempFile.Close()
- _ = os.Remove(tempFile.Name())
+ _ = origTempFile.Close()
+ _ = os.Remove(origTempFile.Name())
}()
- err = data.Callback(tempFile)
+ meta, err := data.Callback(tempFile)
if err != nil {
return false, err
}
- _, err = tempFile.Seek(0, io.SeekStart)
- if err != nil {
- return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ if meta.ReplacementFile != "" {
+ tempFile, err = os.Open(meta.ReplacementFile)
+ if err != nil {
+ return false, fmt.Errorf("failed to open replacement file: %w", err)
+ }
+ defer func() {
+ _ = tempFile.Close()
+ _ = os.Remove(origTempFile.Name())
+ }()
+ } else {
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ }
}
fileInfo, err := tempFile.Stat()
if err != nil {
return false, fmt.Errorf("failed to stat temp file: %w", err)
}
- mimeType := data.ContentType
+ mimeType := meta.ContentType
if mimeType == "" {
buf := make([]byte, 512)
n, err := tempFile.Read(buf)
@@ -491,11 +514,6 @@ var (
ErrPreviewURLNotSupported = mautrix.MUnrecognized.
WithMessage("This is a media proxy and does not support URL previews.").
WithStatus(http.StatusNotImplemented)
- ErrUnknownEndpoint = mautrix.MUnrecognized.
- WithMessage("Unrecognized endpoint")
- ErrUnsupportedMethod = mautrix.MUnrecognized.
- WithMessage("Invalid method for endpoint").
- WithStatus(http.StatusMethodNotAllowed)
)
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
@@ -505,11 +523,3 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request)
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
ErrPreviewURLNotSupported.Write(w)
}
-
-func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
- ErrUnknownEndpoint.Write(w)
-}
-
-func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
- ErrUnsupportedMethod.Write(w)
-}
diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go
new file mode 100644
index 00000000..507c24a5
--- /dev/null
+++ b/mockserver/mockserver.go
@@ -0,0 +1,307 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mockserver
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exerrors"
+ "go.mau.fi/util/exhttp"
+ "go.mau.fi/util/random"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/crypto"
+ "maunium.net/go/mautrix/crypto/cryptohelper"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+func mustDecode(r *http.Request, data any) {
+ exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data))
+}
+
+type userAndDeviceID struct {
+ UserID id.UserID
+ DeviceID id.DeviceID
+}
+
+type MockServer struct {
+ Router *http.ServeMux
+ Server *httptest.Server
+
+ AccessTokenToUserID map[string]userAndDeviceID
+ DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
+ AccountData map[id.UserID]map[event.Type]json.RawMessage
+ DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
+ OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
+ MasterKeys map[id.UserID]mautrix.CrossSigningKeys
+ SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+ UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+
+ PopOTKs bool
+ MemoryStore bool
+}
+
+func Create(t testing.TB) *MockServer {
+ t.Helper()
+
+ server := MockServer{
+ AccessTokenToUserID: map[string]userAndDeviceID{},
+ DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
+ AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ PopOTKs: true,
+ MemoryStore: true,
+ }
+
+ router := http.NewServeMux()
+ router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
+ router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
+ router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim)
+ router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
+ router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
+ router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
+ router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
+ router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
+ server.Router = router
+ server.Server = httptest.NewServer(router)
+ t.Cleanup(server.Server.Close)
+ return &server
+}
+
+func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
+ authHeader := r.Header.Get("Authorization")
+ authHeader = strings.TrimPrefix(authHeader, "Bearer ")
+ userID, ok := ms.AccessTokenToUserID[authHeader]
+ if !ok {
+ panic("no user ID found for access token " + authHeader)
+ }
+ return userID
+}
+
+func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
+ exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
+}
+
+func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
+ var loginReq mautrix.ReqLogin
+ mustDecode(r, &loginReq)
+
+ deviceID := loginReq.DeviceID
+ if deviceID == "" {
+ deviceID = id.DeviceID(random.String(10))
+ }
+
+ accessToken := random.String(30)
+ userID := id.UserID(loginReq.Identifier.User)
+ ms.AccessTokenToUserID[accessToken] = userAndDeviceID{
+ UserID: userID,
+ DeviceID: deviceID,
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
+ AccessToken: accessToken,
+ DeviceID: deviceID,
+ UserID: userID,
+ })
+}
+
+func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqSendToDevice
+ mustDecode(r, &req)
+ evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
+
+ for user, devices := range req.Messages {
+ for device, content := range devices {
+ if _, ok := ms.DeviceInbox[user]; !ok {
+ ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
+ }
+ content.ParseRaw(evtType)
+ ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
+ Sender: ms.getUserID(r).UserID,
+ Type: evtType,
+ Content: *content,
+ })
+ }
+ }
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
+ userID := id.UserID(r.PathValue("userID"))
+ eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
+
+ jsonData, _ := io.ReadAll(r.Body)
+ if _, ok := ms.AccountData[userID]; !ok {
+ ms.AccountData[userID] = map[event.Type]json.RawMessage{}
+ }
+ ms.AccountData[userID][eventType] = json.RawMessage(jsonData)
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqQueryKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespQueryKeys{
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ }
+ for user := range req.DeviceKeys {
+ resp.MasterKeys[user] = ms.MasterKeys[user]
+ resp.UserSigningKeys[user] = ms.UserSigningKeys[user]
+ resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
+ resp.DeviceKeys[user] = ms.DeviceKeys[user]
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqClaimKeys
+ mustDecode(r, &req)
+ resp := mautrix.RespClaimKeys{
+ OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
+ }
+ for user, devices := range req.OneTimeKeys {
+ resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ for device := range devices {
+ keys := ms.OneTimeKeys[user][device]
+ for keyID, key := range keys {
+ if ms.PopOTKs {
+ delete(keys, keyID)
+ }
+ resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{
+ keyID: key,
+ }
+ break
+ }
+ }
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
+}
+
+func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqUploadKeys
+ mustDecode(r, &req)
+
+ uid := ms.getUserID(r)
+ userID := uid.UserID
+ if _, ok := ms.DeviceKeys[userID]; !ok {
+ ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
+ }
+ if _, ok := ms.OneTimeKeys[userID]; !ok {
+ ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
+ }
+
+ if req.DeviceKeys != nil {
+ ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys
+ }
+ otks, ok := ms.OneTimeKeys[userID][uid.DeviceID]
+ if !ok {
+ otks = map[id.KeyID]mautrix.OneTimeKey{}
+ ms.OneTimeKeys[userID][uid.DeviceID] = otks
+ }
+ if req.OneTimeKeys != nil {
+ maps.Copy(otks, req.OneTimeKeys)
+ }
+
+ exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
+ OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)},
+ })
+}
+
+func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.UploadCrossSigningKeysReq[any]
+ mustDecode(r, &req)
+
+ userID := ms.getUserID(r).UserID
+ ms.MasterKeys[userID] = req.Master
+ ms.SelfSigningKeys[userID] = req.SelfSigning
+ ms.UserSigningKeys[userID] = req.UserSigning
+
+ ms.emptyResp(w, r)
+}
+
+func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
+ t.Helper()
+ if ctx == nil {
+ ctx = context.TODO()
+ }
+ client, err := mautrix.NewClient(ms.Server.URL, "", "")
+ require.NoError(t, err)
+ client.Client = ms.Server.Client()
+
+ _, err = client.Login(ctx, &mautrix.ReqLogin{
+ Type: mautrix.AuthTypePassword,
+ Identifier: mautrix.UserIdentifier{
+ Type: mautrix.IdentifierTypeUser,
+ User: userID.String(),
+ },
+ DeviceID: deviceID,
+ Password: "password",
+ StoreCredentials: true,
+ })
+ require.NoError(t, err)
+
+ var store any
+ if ms.MemoryStore {
+ store = crypto.NewMemoryStore(nil)
+ client.StateStore = mautrix.NewMemoryStateStore()
+ } else {
+ store, err = dbutil.NewFromConfig("", dbutil.Config{
+ PoolConfig: dbutil.PoolConfig{
+ Type: "sqlite3-fk-wal",
+ URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
+ MaxOpenConns: 5,
+ MaxIdleConns: 1,
+ },
+ }, nil)
+ require.NoError(t, err)
+ }
+ cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
+ require.NoError(t, err)
+ client.Crypto = cryptoHelper
+
+ err = cryptoHelper.Init(ctx)
+ require.NoError(t, err)
+
+ machineLog := globallog.Logger.With().
+ Stringer("my_user_id", userID).
+ Stringer("my_device_id", deviceID).
+ Logger()
+ cryptoHelper.Machine().Log = &machineLog
+
+ err = cryptoHelper.Machine().ShareKeys(ctx, 50)
+ require.NoError(t, err)
+
+ return client, cryptoHelper.Machine().CryptoStore
+}
+
+func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
+ t.Helper()
+
+ for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
+ client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
+ ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
+ }
+}
diff --git a/pushrules/action.go b/pushrules/action.go
index 9838e88b..b5a884b2 100644
--- a/pushrules/action.go
+++ b/pushrules/action.go
@@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
if ok {
action.Action = ActionSetTweak
action.Tweak = PushActionTweak(tweak)
- action.Value, _ = val["value"]
+ action.Value = val["value"]
}
}
return nil
diff --git a/pushrules/action_test.go b/pushrules/action_test.go
index a8f68415..3c0aa168 100644
--- a/pushrules/action_test.go
+++ b/pushrules/action_test.go
@@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
err = pa.UnmarshalJSON([]byte(`9001`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`"foo"`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.PushActionType("foo"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`))
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, pushrules.ActionSetTweak, pa.Action)
assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak)
@@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data)
}
@@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.Equal(t, []byte(`"something else"`), data)
}
diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go
index 0d3eaf7a..37af3e34 100644
--- a/pushrules/condition_test.go
+++ b/pushrules/condition_test.go
@@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
}
}
-func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
- return &pushrules.PushCondition{
- Kind: pushrules.KindEventPropertyContains,
- Key: key,
- Value: value,
- }
-}
-
func TestPushCondition_Match_InvalidKind(t *testing.T) {
condition := &pushrules.PushCondition{
Kind: pushrules.PushCondKind("invalid"),
diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go
index a531ca28..a5a0f5e7 100644
--- a/pushrules/pushrules_test.go
+++ b/pushrules/pushrules_test.go
@@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) {
},
}
pushRuleset, err := pushrules.EventToPushRules(evt)
- assert.Nil(t, err)
+ assert.NoError(t, err)
assert.NotNil(t, pushRuleset)
assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{})
diff --git a/requests.go b/requests.go
index 09e4b3cd..cc8b7266 100644
--- a/requests.go
+++ b/requests.go
@@ -2,6 +2,7 @@ package mautrix
import (
"encoding/json"
+ "fmt"
"strconv"
"time"
@@ -39,20 +40,40 @@ const (
type Direction rune
+func (d Direction) MarshalJSON() ([]byte, error) {
+ return json.Marshal(string(d))
+}
+
+func (d *Direction) UnmarshalJSON(data []byte) error {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ switch str {
+ case "f":
+ *d = DirectionForward
+ case "b":
+ *d = DirectionBackward
+ default:
+ return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str)
+ }
+ return nil
+}
+
const (
DirectionForward Direction = 'f'
DirectionBackward Direction = 'b'
)
// ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
-type ReqRegister struct {
+type ReqRegister[UIAType any] struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
InhibitLogin bool `json:"inhibit_login,omitempty"`
RefreshToken bool `json:"refresh_token,omitempty"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
// Type for registration, only used for appservice user registrations
// https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions
@@ -120,11 +141,12 @@ type ReqCreateRoom struct {
InitialState []*event.Event `json:"initial_state,omitempty"`
Preset string `json:"preset,omitempty"`
IsDirect bool `json:"is_direct,omitempty"`
- RoomVersion event.RoomVersion `json:"room_version,omitempty"`
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"`
MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"`
+ MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"`
BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"`
BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"`
BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"`
@@ -161,6 +183,11 @@ type ReqKnockRoom struct {
Reason string `json:"reason,omitempty"`
}
+type ReqSearchUserDirectory struct {
+ SearchTerm string `json:"search_term"`
+ Limit int `json:"limit,omitempty"`
+}
+
type ReqMutualRooms struct {
From string `json:"-"`
}
@@ -293,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 {
return ""
}
-type UploadCrossSigningKeysReq struct {
+type UploadCrossSigningKeysReq[UIAType any] struct {
Master CrossSigningKeys `json:"master_key"`
SelfSigning CrossSigningKeys `json:"self_signing_key"`
UserSigning CrossSigningKeys `json:"user_signing_key"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type KeyMap map[id.DeviceKeyID]string
@@ -340,18 +367,23 @@ type ReqSendToDevice struct {
}
type ReqSendEvent struct {
- Timestamp int64
- TransactionID string
- UnstableDelay time.Duration
+ Timestamp int64
+ TransactionID string
+ UnstableDelay time.Duration
+ UnstableStickyDuration time.Duration
+ DontEncrypt bool
+ MeowEventID id.EventID
+}
- DontEncrypt bool
-
- MeowEventID id.EventID
+type ReqDelayedEvents struct {
+ DelayID id.DelayID `json:"-"`
+ Status event.DelayStatus `json:"-"`
+ NextBatch string `json:"-"`
}
type ReqUpdateDelayedEvent struct {
- DelayID string `json:"-"`
- Action string `json:"action"` // TODO use enum
+ DelayID id.DelayID `json:"-"`
+ Action event.DelayAction `json:"action"`
}
// ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid
@@ -360,14 +392,14 @@ type ReqDeviceInfo struct {
}
// ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid
-type ReqDeleteDevice struct {
- Auth interface{} `json:"auth,omitempty"`
+type ReqDeleteDevice[UIAType any] struct {
+ Auth UIAType `json:"auth,omitempty"`
}
// ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices
-type ReqDeleteDevices struct {
+type ReqDeleteDevices[UIAType any] struct {
Devices []id.DeviceID `json:"devices"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type ReqPutPushRule struct {
@@ -379,18 +411,6 @@ type ReqPutPushRule struct {
Pattern string `json:"pattern"`
}
-// Deprecated: MSC2716 was abandoned
-type ReqBatchSend struct {
- PrevEventID id.EventID `json:"-"`
- BatchID id.BatchID `json:"-"`
-
- BeeperNewMessages bool `json:"-"`
- BeeperMarkReadBy id.UserID `json:"-"`
-
- StateEventsAtStart []*event.Event `json:"state_events_at_start"`
- Events []*event.Event `json:"events"`
-}
-
type ReqBeeperBatchSend struct {
// ForwardIfNoMessages should be set to true if the batch should be forward
// backfilled if there are no messages currently in the room.
@@ -586,3 +606,13 @@ func (rgr *ReqGetRelations) Query() map[string]string {
}
return query
}
+
+// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqSuspend struct {
+ Suspended bool `json:"suspended"`
+}
+
+// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type ReqLocked struct {
+ Locked bool `json:"locked"`
+}
diff --git a/responses.go b/responses.go
index 20d02af5..4fbe1fbc 100644
--- a/responses.go
+++ b/responses.go
@@ -6,12 +6,14 @@ import (
"fmt"
"maps"
"reflect"
+ "slices"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -104,11 +106,22 @@ type RespContext struct {
type RespSendEvent struct {
EventID id.EventID `json:"event_id"`
- UnstableDelayID string `json:"delay_id,omitempty"`
+ UnstableDelayID id.DelayID `json:"delay_id,omitempty"`
}
type RespUpdateDelayedEvent struct{}
+type RespDelayedEvents struct {
+ Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"`
+ Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"`
+ NextBatch string `json:"next_batch,omitempty"`
+
+ // Deprecated: Synapse implementation still returns this
+ DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"`
+ // Deprecated: Synapse implementation still returns this
+ FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"`
+}
+
type RespRedactUserEvents struct {
IsMoreEvents bool `json:"is_more_events"`
RedactedEvents struct {
@@ -210,25 +223,52 @@ func (r *RespUserProfile) MarshalJSON() ([]byte, error) {
} else {
delete(marshalMap, "avatar_url")
}
- return json.Marshal(r.Extra)
+ return json.Marshal(marshalMap)
+}
+
+type RespSearchUserDirectory struct {
+ Limited bool `json:"limited"`
+ Results []*UserDirectoryEntry `json:"results"`
+}
+
+type UserDirectoryEntry struct {
+ RespUserProfile
+ UserID id.UserID `json:"user_id"`
+}
+
+func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error {
+ err := r.RespUserProfile.UnmarshalJSON(data)
+ if err != nil {
+ return err
+ }
+ userIDStr, _ := r.Extra["user_id"].(string)
+ r.UserID = id.UserID(userIDStr)
+ delete(r.Extra, "user_id")
+ return nil
+}
+
+func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) {
+ if r.Extra == nil {
+ r.Extra = make(map[string]any)
+ }
+ r.Extra["user_id"] = r.UserID.String()
+ return r.RespUserProfile.MarshalJSON()
}
type RespMutualRooms struct {
Joined []id.RoomID `json:"joined"`
NextBatch string `json:"next_batch,omitempty"`
+ Count int `json:"count,omitempty"`
}
type RespRoomSummary struct {
PublicRoomInfo
- Membership event.Membership `json:"membership,omitempty"`
- RoomVersion event.RoomVersion `json:"room_version,omitempty"`
- Encryption id.Algorithm `json:"encryption,omitempty"`
- AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
+ Membership event.Membership `json:"membership,omitempty"`
- UnstableRoomVersion event.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
- UnstableRoomVersionOld event.RoomVersion `json:"im.nheko.summary.version,omitempty"`
- UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"`
+ UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
+ UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"`
+ UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"`
}
// RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable
@@ -302,6 +342,24 @@ type LazyLoadSummary struct {
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
}
+func (lls *LazyLoadSummary) MemberCount() int {
+ if lls == nil {
+ return 0
+ }
+ return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount)
+}
+
+func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool {
+ if lls == other {
+ return true
+ } else if lls == nil || other == nil {
+ return false
+ }
+ return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) &&
+ ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) &&
+ slices.Equal(lls.Heroes, other.Heroes)
+}
+
type SyncEventsList struct {
Events []*event.Event `json:"events,omitempty"`
}
@@ -397,7 +455,7 @@ type BeeperInboxPreviewEvent struct {
type SyncJoinedRoom struct {
Summary LazyLoadSummary `json:"summary"`
State SyncEventsList `json:"state"`
- StateAfter *SyncEventsList `json:"org.matrix.msc4222.state_after,omitempty"`
+ StateAfter *SyncEventsList `json:"state_after,omitempty"`
Timeline SyncTimeline `json:"timeline"`
Ephemeral SyncEventsList `json:"ephemeral"`
AccountData SyncEventsList `json:"account_data"`
@@ -488,30 +546,19 @@ type RespDeviceInfo struct {
LastSeenTS int64 `json:"last_seen_ts"`
}
-// Deprecated: MSC2716 was abandoned
-type RespBatchSend struct {
- StateEventIDs []id.EventID `json:"state_event_ids"`
- EventIDs []id.EventID `json:"event_ids"`
-
- InsertionEventID id.EventID `json:"insertion_event_id"`
- BatchEventID id.EventID `json:"batch_event_id"`
- BaseInsertionEventID id.EventID `json:"base_insertion_event_id"`
-
- NextBatchID id.BatchID `json:"next_batch_id"`
-}
-
type RespBeeperBatchSend struct {
EventIDs []id.EventID `json:"event_ids"`
}
// RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities
type RespCapabilities struct {
- RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
- ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
- SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
- SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
- ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
- GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"`
+ RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
+ ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
+ SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
+ SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
+ ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"`
+ UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"`
Custom map[string]interface{} `json:"-"`
}
@@ -620,6 +667,11 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool {
return available
}
+type CapUnstableAccountModeration struct {
+ Suspend bool `json:"suspend"`
+ Lock bool `json:"lock"`
+}
+
type RespPublicRooms struct {
Chunk []*PublicRoomInfo `json:"chunk"`
NextBatch string `json:"next_batch,omitempty"`
@@ -638,6 +690,10 @@ type PublicRoomInfo struct {
RoomType event.RoomType `json:"room_type"`
Topic string `json:"topic,omitempty"`
WorldReadable bool `json:"world_readable"`
+
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Encryption id.Algorithm `json:"encryption,omitempty"`
+ AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
}
// RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
@@ -648,12 +704,7 @@ type RespHierarchy struct {
type ChildRoomsChunk struct {
PublicRoomInfo
- ChildrenState []StrippedStateWithTime `json:"children_state"`
-}
-
-type StrippedStateWithTime struct {
- event.StrippedState
- Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
+ ChildrenState []*event.Event `json:"children_state"`
}
type RespAppservicePing struct {
@@ -716,3 +767,33 @@ type RespGetRelations struct {
PrevBatch string `json:"prev_batch,omitempty"`
RecursionDepth int `json:"recursion_depth,omitempty"`
}
+
+// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespSuspended struct {
+ Suspended bool `json:"suspended"`
+}
+
+// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
+type RespLocked struct {
+ Locked bool `json:"locked"`
+}
+
+type ConnectionInfo struct {
+ IP string `json:"ip,omitempty"`
+ LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"`
+ UserAgent string `json:"user_agent,omitempty"`
+}
+
+type SessionInfo struct {
+ Connections []ConnectionInfo `json:"connections,omitempty"`
+}
+
+type DeviceInfo struct {
+ Sessions []SessionInfo `json:"sessions,omitempty"`
+}
+
+// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
+type RespWhoIs struct {
+ UserID id.UserID `json:"user_id,omitempty"`
+ Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"`
+}
diff --git a/responses_test.go b/responses_test.go
index b23d85ad..73d82635 100644
--- a/responses_test.go
+++ b/responses_test.go
@@ -8,7 +8,6 @@ package mautrix_test
import (
"encoding/json"
- "fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) {
var caps mautrix.RespCapabilities
err := json.Unmarshal([]byte(sampleData), &caps)
require.NoError(t, err)
- fmt.Println(caps)
require.NotNil(t, caps.RoomVersions)
assert.Equal(t, "9", caps.RoomVersions.Default)
diff --git a/room.go b/room.go
index c3ddb7e6..4292bff5 100644
--- a/room.go
+++ b/room.go
@@ -5,8 +5,6 @@ import (
"maunium.net/go/mautrix/id"
)
-type RoomStateMap = map[event.Type]map[string]*event.Event
-
// Room represents a single Matrix room.
type Room struct {
ID id.RoomID
@@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) {
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
- stateEventMap, _ := room.State[eventType]
- evt, _ := stateEventMap[stateKey]
+ stateEventMap := room.State[eventType]
+ evt := stateEventMap[stateKey]
return evt
}
diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go
index 4a220a2b..11957dfa 100644
--- a/sqlstatestore/statestore.go
+++ b/sqlstatestore/statestore.go
@@ -62,6 +62,9 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID)
}
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
+ if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
return err
}
@@ -182,6 +185,11 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID,
}
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
@@ -214,6 +222,11 @@ func (u *userProfileRow) GetMassInsertValues() [5]any {
var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ } else if userID == "" {
+ return fmt.Errorf("user ID is empty")
+ }
var nameSkeleton []byte
if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
@@ -235,6 +248,9 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room
const userProfileMassInsertBatchSize = 500
func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
return store.DoTxn(ctx, nil, func(ctx context.Context) error {
err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
if err != nil {
@@ -305,6 +321,9 @@ func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.Roo
}
func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true)
ON CONFLICT (room_id) DO UPDATE SET members_fetched=true
@@ -334,6 +353,9 @@ func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID)
}
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
contentBytes, err := json.Marshal(content)
if err != nil {
return fmt.Errorf("failed to marshal content JSON: %w", err)
@@ -348,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
var data []byte
err := store.
- QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
+ QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID).
Scan(&data)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
@@ -371,6 +393,9 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (
}
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
@@ -379,89 +404,92 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID
}
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
+ levels = &event.PowerLevelsEventContent{}
err = store.
- QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
- Scan(&dbutil.JSON{Data: &levels})
+ QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent})
if errors.Is(err, sql.ErrNoRows) {
- err = nil
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if levels.CreateEvent != nil {
+ err = levels.CreateEvent.Content.ParseRaw(event.StateCreate)
}
return
}
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
- if store.Dialect == dbutil.Postgres {
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, userID).
- Scan(&powerLevel)
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetUserLevel(userID), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetUserLevel(userID), nil
}
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var powerLevel int
- err := store.
- QueryRow(ctx, `
- SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
- FROM mx_room_state WHERE room_id=$1
- `, roomID, eventType.Type, defaultType, defaultValue).
- Scan(&powerLevel)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- powerLevel = defaultValue
- }
- return powerLevel, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
- }
- return levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
}
+ return levels.GetEventLevel(eventType), nil
}
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
- if store.Dialect == dbutil.Postgres {
- defaultType := "events_default"
- defaultValue := 0
- if eventType.IsState() {
- defaultType = "state_default"
- defaultValue = 50
- }
- var hasPower bool
- err := store.
- QueryRow(ctx, `SELECT
- COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
- >=
- COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
- FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
- Scan(&hasPower)
- if errors.Is(err, sql.ErrNoRows) {
- err = nil
- hasPower = defaultValue == 0
- }
- return hasPower, err
- } else {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return false, err
- }
- return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return false, err
}
+ return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
+}
+
+func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ if evt.Type != event.StateCreate {
+ return fmt.Errorf("invalid event type for create event: %s", evt.Type)
+ } else if evt.RoomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event
+ `, evt.RoomID, dbutil.JSON{Data: evt})
+ return err
+}
+
+func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) {
+ err = store.
+ QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &evt})
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ if evt != nil {
+ err = evt.Content.ParseRaw(event.StateCreate)
+ }
+ return
+}
+
+func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error {
+ if roomID == "" {
+ return fmt.Errorf("room ID is empty")
+ }
+ _, err := store.Exec(ctx, `
+ INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2)
+ ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules
+ `, roomID, dbutil.JSON{Data: rules})
+ return err
+}
+
+func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) {
+ levels = &event.JoinRulesEventContent{}
+ err = store.
+ QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID).
+ Scan(&dbutil.JSON{Data: &levels})
+ if errors.Is(err, sql.ErrNoRows) {
+ levels = nil
+ err = nil
+ }
+ return
}
diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql
index a58cc56a..4679f1c6 100644
--- a/sqlstatestore/v00-latest-revision.sql
+++ b/sqlstatestore/v00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v7 (compatible with v3+): Latest revision
+-- v0 -> v10 (compatible with v3+): Latest revision
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
@@ -26,5 +26,7 @@ CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
power_levels jsonb,
encryption jsonb,
+ create_event jsonb,
+ join_rules jsonb,
members_fetched BOOLEAN NOT NULL DEFAULT false
);
diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql
new file mode 100644
index 00000000..9f1b55c9
--- /dev/null
+++ b/sqlstatestore/v08-create-event.sql
@@ -0,0 +1,2 @@
+-- v8 (compatible with v3+): Add create event to room state table
+ALTER TABLE mx_room_state ADD COLUMN create_event jsonb;
diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql
new file mode 100644
index 00000000..ca951068
--- /dev/null
+++ b/sqlstatestore/v09-clear-empty-room-ids.sql
@@ -0,0 +1,3 @@
+-- v9 (compatible with v3+): Clear invalid rows
+DELETE FROM mx_room_state WHERE room_id='';
+DELETE FROM mx_user_profile WHERE room_id='' OR user_id='';
diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql
new file mode 100644
index 00000000..3074c46a
--- /dev/null
+++ b/sqlstatestore/v10-join-rules.sql
@@ -0,0 +1,2 @@
+-- v10 (compatible with v3+): Add join rules to room state table
+ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb;
diff --git a/statestore.go b/statestore.go
index e728b885..2bd498dd 100644
--- a/statestore.go
+++ b/statestore.go
@@ -34,6 +34,12 @@ type StateStore interface {
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
+ SetCreate(ctx context.Context, evt *event.Event) error
+ GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error)
+
+ GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error)
+ SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error
+
HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error)
MarkMembersFetched(ctx context.Context, roomID id.RoomID) error
GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error)
@@ -68,9 +74,13 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
err = store.SetPowerLevels(ctx, evt.RoomID, content)
case *event.EncryptionEventContent:
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
+ case *event.CreateEventContent:
+ err = store.SetCreate(ctx, evt)
+ case *event.JoinRulesEventContent:
+ err = store.SetJoinRules(ctx, evt.RoomID, content)
default:
switch evt.Type {
- case event.StateMember, event.StatePowerLevels, event.StateEncryption:
+ case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate:
zerolog.Ctx(ctx).Warn().
Stringer("event_id", evt.ID).
Str("event_type", evt.Type.Type).
@@ -101,11 +111,14 @@ type MemoryStateStore struct {
MembersFetched map[id.RoomID]bool `json:"members_fetched"`
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
+ Create map[id.RoomID]*event.Event `json:"create"`
+ JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"`
registrationsLock sync.RWMutex
membersLock sync.RWMutex
powerLevelsLock sync.RWMutex
encryptionLock sync.RWMutex
+ joinRulesLock sync.RWMutex
}
func NewMemoryStateStore() StateStore {
@@ -115,6 +128,8 @@ func NewMemoryStateStore() StateStore {
MembersFetched: make(map[id.RoomID]bool),
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
+ Create: make(map[id.RoomID]*event.Event),
+ JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent),
}
}
@@ -298,6 +313,9 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
+ if levels != nil && levels.CreateEvent == nil {
+ levels.CreateEvent = store.Create[roomID]
+ }
store.powerLevelsLock.RUnlock()
return
}
@@ -314,6 +332,23 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
}
+func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
+ store.powerLevelsLock.Lock()
+ store.Create[evt.RoomID] = evt
+ if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil {
+ pls.CreateEvent = evt
+ }
+ store.powerLevelsLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) {
+ store.powerLevelsLock.RLock()
+ evt := store.Create[roomID]
+ store.powerLevelsLock.RUnlock()
+ return evt, nil
+}
+
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
@@ -327,6 +362,19 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R
return store.Encryption[roomID], nil
}
+func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error {
+ store.joinRulesLock.Lock()
+ store.JoinRules[roomID] = content
+ store.joinRulesLock.Unlock()
+ return nil
+}
+
+func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) {
+ store.joinRulesLock.RLock()
+ defer store.joinRulesLock.RUnlock()
+ return store.JoinRules[roomID], nil
+}
+
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go
index a09ba174..0925b748 100644
--- a/synapseadmin/roomapi.go
+++ b/synapseadmin/roomapi.go
@@ -75,8 +75,7 @@ type RespListRooms struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
var resp RespListRooms
- var reqURL string
- reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
+ reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
_, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
@@ -117,6 +116,7 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to
type ReqDeleteRoom struct {
Purge bool `json:"purge,omitempty"`
+ ForcePurge bool `json:"force_purge,omitempty"`
Block bool `json:"block,omitempty"`
Message string `json:"message,omitempty"`
RoomName string `json:"room_name,omitempty"`
diff --git a/sync.go b/sync.go
index 9a2b9edf..598df8e0 100644
--- a/sync.go
+++ b/sync.go
@@ -90,6 +90,7 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
}
}()
+ ctx = context.WithValue(ctx, SyncTokenContextKey, since)
for _, listener := range s.syncListeners {
if !listener(ctx, res, since) {
@@ -263,7 +264,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
for _, meta := range resp.Rooms.Invite {
- var inviteState []event.StrippedState
+ var inviteState []*event.Event
var inviteEvt *event.Event
for _, evt := range meta.State.Events {
if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() {
@@ -271,12 +272,7 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string
} else {
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
- inviteState = append(inviteState, event.StrippedState{
- Content: evt.Content,
- Type: evt.Type,
- StateKey: evt.GetStateKey(),
- Sender: evt.Sender,
- })
+ inviteState = append(inviteState, evt)
}
}
if inviteEvt != nil {
diff --git a/url.go b/url.go
index d888956a..91b3d49d 100644
--- a/url.go
+++ b/url.go
@@ -98,10 +98,8 @@ func (saup SynapseAdminURLPath) FullPath() []any {
// and appservice user ID set already.
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
- if urlQuery != nil {
- for k, v := range urlQuery {
- q.Set(k, v)
- }
+ for k, v := range urlQuery {
+ q.Set(k, v)
}
})
}
diff --git a/version.go b/version.go
index 6b8af5ef..f00bbf39 100644
--- a/version.go
+++ b/version.go
@@ -4,10 +4,11 @@ import (
"fmt"
"regexp"
"runtime"
+ "runtime/debug"
"strings"
)
-const Version = "v0.24.2"
+const Version = "v0.26.3"
var GoModVersion = ""
var Commit = ""
@@ -15,11 +16,20 @@ var VersionWithCommit = Version
var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go")
-var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`)
-
func init() {
+ if GoModVersion == "" {
+ info, _ := debug.ReadBuildInfo()
+ if info != nil {
+ for _, mod := range info.Deps {
+ if mod.Path == "maunium.net/go/mautrix" {
+ GoModVersion = mod.Version
+ break
+ }
+ }
+ }
+ }
if GoModVersion != "" {
- match := goModVersionRegex.FindStringSubmatch(GoModVersion)
+ match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion)
if match != nil {
Commit = match[1]
}
diff --git a/versions.go b/versions.go
index f87bddda..61b2e4ea 100644
--- a/versions.go
+++ b/versions.go
@@ -60,20 +60,28 @@ type UnstableFeature struct {
}
var (
- FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
- FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
- FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
- FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
- FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
- FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
+ FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
+ FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
+ FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
+ FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
+ FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
+ FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"}
+ FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"}
+ FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116}
+ FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"}
- BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
- BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
- BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
- BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
- BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
- BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
- BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
+ BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
+ BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
+ BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
+ BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
+ BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
+ BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
+ BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
+ BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"}
+ BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"}
)
func (versions *RespVersions) Supports(feature UnstableFeature) bool {
@@ -117,6 +125,8 @@ var (
SpecV113 = MustParseSpecVersion("v1.13")
SpecV114 = MustParseSpecVersion("v1.14")
SpecV115 = MustParseSpecVersion("v1.15")
+ SpecV116 = MustParseSpecVersion("v1.16")
+ SpecV117 = MustParseSpecVersion("v1.17")
)
func (svf SpecVersionFormat) String() string {