diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index c0add220..71c1988b 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -10,12 +10,12 @@ jobs:
runs-on: ubuntu-latest
name: Lint (latest)
steps:
- - uses: actions/checkout@v6
+ - uses: actions/checkout@v4
- name: Set up Go
- uses: actions/setup-go@v6
+ uses: actions/setup-go@v5
with:
- go-version: "1.26"
+ go-version: "1.24"
cache: true
- name: Install libolm
@@ -24,7 +24,6 @@ jobs:
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
- go install honnef.co/go/tools/cmd/staticcheck@latest
export PATH="$HOME/go/bin:$PATH"
- name: Run pre-commit
@@ -35,14 +34,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- go-version: ["1.25", "1.26"]
- name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm)
+ go-version: ["1.23", "1.24"]
+ name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, libolm)
steps:
- - uses: actions/checkout@v6
+ - uses: actions/checkout@v4
- name: Set up Go ${{ matrix.go-version }}
- uses: actions/setup-go@v6
+ uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: true
@@ -61,28 +60,28 @@ jobs:
- name: Test
run: go test -json -v ./... 2>&1 | gotestfmt
- - name: Test (jsonv2)
- env:
- GOEXPERIMENT: jsonv2
- run: go test -json -v ./... 2>&1 | gotestfmt
-
build-goolm:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- go-version: ["1.25", "1.26"]
- name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm)
+ go-version: ["1.23", "1.24"]
+ name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, goolm)
steps:
- - uses: actions/checkout@v6
+ - uses: actions/checkout@v4
- name: Set up Go ${{ matrix.go-version }}
- uses: actions/setup-go@v6
+ uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: true
+ - name: Set up gotestfmt
+ uses: GoTestTools/gotestfmt-action@v2
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+
- name: Build
run: |
rm -rf crypto/libolm
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 9a9e7375..578349c9 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -17,7 +17,7 @@ jobs:
lock-stale:
runs-on: ubuntu-latest
steps:
- - uses: dessant/lock-threads@v6
+ - uses: dessant/lock-threads@v5
id: lock
with:
issue-inactive-days: 90
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 616fccb2..81701203 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v6.0.0
+ rev: v5.0.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
@@ -9,7 +9,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/tekwizely/pre-commit-golang
- rev: v1.0.0-rc.4
+ rev: v1.0.0-rc.1
hooks:
- id: go-imports-repo
args:
@@ -18,7 +18,8 @@ repos:
- "-w"
- id: go-vet-repo-mod
- id: go-mod-tidy
- - id: go-staticcheck-repo-mod
+ # TODO enable this
+ #- id: go-staticcheck-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.4.2
@@ -26,4 +27,3 @@ repos:
- id: prevent-literal-http-methods
- id: zerolog-ban-global-log
- id: zerolog-ban-msgf
- - id: zerolog-use-stringer
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f2829199..8e71381e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,237 +1,3 @@
-## v0.26.3 (2026-02-16)
-
-* Bumped minimum Go version to 1.25.
-* *(client)* Added fields for sending [MSC4354] sticky events.
-* *(bridgev2)* Added automatic message request accepting when sending message.
-* *(mediaproxy)* Added support for federation thumbnail endpoint.
-* *(crypto/ssss)* Improved support for recovery keys with slightly broken
- metadata.
-* *(crypto)* Changed key import to call session received callback even for
- sessions that already exist in the database.
-* *(appservice)* Fixed building websocket URL accidentally using file path
- separators instead of always `/`.
-* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field.
-* *(client)* Fixed incorrect context usage in async uploads.
-* *(crypto)* Fixed panic when passing invalid input to megolm message index
- parser used for debugging.
-* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned
- up properly.
-
-[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354
-
-## v0.26.2 (2026-01-16)
-
-* *(bridgev2)* Added chunked portal deletion to avoid database locks when
- deleting large portals.
-* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata
- as per [MSC4392].
-* *(bridgev2/login)* Added `default_value` for user input fields.
-* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested
- HTTP client settings and to reset active connections of the network connector.
-* *(bridgev2)* Added interface to let network connectors get the provisioning
- API HTTP router and add new endpoints.
-* *(event)* Added blurhash field to Beeper link preview objects.
-* *(event)* Added [MSC4391] support for bot commands.
-* *(event)* Dropped [MSC4332] support for bot commands.
-* *(client)* Changed media download methods to return an error if the provided
- MXC URI is empty.
-* *(client)* Stabilized support for [MSC4323].
-* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events.
-* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with
- a portal re-ID call.
-
-[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391
-[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392
-
-## v0.26.1 (2025-12-16)
-
-* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return
- the mime type from the callback rather than in the return get media return
- value. The callback can now also redirect the caller to a different file.
-* *(federation)* Added join/knock/leave functions
- (thanks to [@nexy7574] in [#422]).
-* *(federation/eventauth)* Fixed various incorrect checks.
-* *(client)* Added backoff for retrying media uploads to external URLs
- (with MSC3870).
-* *(bridgev2/config)* Added support for overriding config fields using
- environment variables.
-* *(bridgev2/commands)* Added command to mute chat on remote network.
-* *(bridgev2)* Added interface for network connectors to redirect to a different
- user ID when handling an invite from Matrix.
-* *(bridgev2)* Added interface for signaling message request status of portals.
-* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag
- is set in chat info.
-* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if
- bridging the new one is successful.
-* *(bridgev2/mxmain)* Improved error message when trying to run bridge with
- pre-megabridge database when no database migration exists.
-* *(bridgev2)* Improved reliability of database migration when enabling split
- portals.
-* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats.
-* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public
- portal rooms.
-* *(bridgev2)* Stopped hardcoding room versions in favor of checking
- server capabilities to determine appropriate `/createRoom` parameters.
-
-[#422]: https://github.com/mautrix/go/pull/422
-
-## v0.26.0 (2025-11-16)
-
-* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent`
- has been able to do the same for a while now.
-* *(client,federation)* Added size limits for responses to make it safer to send
- requests to untrusted servers.
-* *(client)* Added wrapper for `/admin/whois` client API
- (thanks to [@nexy7574] in [#411]).
-* *(synapseadmin)* Added `force_purge` option to DeleteRoom
- (thanks to [@nexy7574] in [#420]).
-* *(statestore)* Added saving join rules for rooms.
-* *(bridgev2)* Added optional automatic rollback of room state if bridging the
- change to the remote network fails.
-* *(bridgev2)* Added management room notices if transient disconnect state
- doesn't resolve within 3 minutes.
-* *(bridgev2)* Added interface to signal that certain participants couldn't be
- invited when creating a group.
-* *(bridgev2)* Added `select` type for user input fields in login.
-* *(bridgev2)* Added interface to let network connector customize personal
- filtering space.
-* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to
- other bots.
-* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever
- possible.
-* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names,
- and encrypted files.
-* *(bridgev2/commands)* Added command to resync a single portal.
-* *(bridgev2/commands)* Added create group command.
-* *(bridgev2/config)* Added option to limit maximum number of logins.
-* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room
- is public.
-* *(bridgev2/disappear)* Changed read receipt handling to only start
- disappearing timers for messages up to the read message (note: may not work in
- all cases if the read receipt points at an unknown event).
-* *(event/reply)* Changed plaintext reply fallback removal to only happen when
- an HTML reply fallback is removed successfully.
-* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run.
-* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages.
-* *(federation)* Fixed HTTP method for sending transactions
- (thanks to [@nexy7574] in [#426]).
-* *(federation)* Fixed response body being closed even when using `DontReadBody`
- parameter.
-* *(federation)* Fixed validating auth for requests with query params.
-* *(federation/eventauth)* Fixed typo causing restricted joins to not work.
-
-[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169
-[#411]: github.com/mautrix/go/pull/411
-[#420]: github.com/mautrix/go/pull/420
-[#426]: github.com/mautrix/go/pull/426
-
-## v0.25.2 (2025-10-16)
-
-* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into
- `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old
- behavior, but most users likely want the relaxed version, as there are real
- users whose user IDs aren't valid under the strict rules.
-* *(crypto)* Added helper methods for generating and verifying with recovery
- keys.
-* *(bridgev2/matrix)* Added config option to automatically generate a recovery
- key for the bridge bot and self-sign the bridge's device.
-* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode
- for encryption with standard servers like Synapse.
-* *(bridgev2)* Added optional support for implicit read receipts.
-* *(bridgev2)* Added interface for deleting chats on remote network.
-* *(bridgev2)* Added local enforcement of media duration and size limits.
-* *(bridgev2)* Extended event duration logging to log any event taking too long.
-* *(bridgev2)* Improved validation in group creation provisioning API.
-* *(event)* Added event type constant for poll end events.
-* *(client)* Added wrapper for searching user directory.
-* *(client)* Improved support for managing [MSC4140] delayed events.
-* *(crypto/helper)* Changed default sync handling to not block on waiting for
- decryption keys. On initial sync, keys won't be requested at all by default.
-* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1).
-* *(bridgev2)* Fixed various bugs with migrating to split portals.
-* *(event)* Fixed poll start events having incorrect null `m.relates_to`.
-* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling.
-* *(federation)* Fixed various bugs in event auth.
-
-## v0.25.1 (2025-09-16)
-
-* *(client)* Fixed HTTP method of delete devices API call
- (thanks to [@fmseals] in [#393]).
-* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints
- (thanks to [@nexy7574] in [#407]).
-* *(client)* Stabilized support for extensible profiles.
-* *(client)* Stabilized support for `state_after` in sync.
-* *(client)* Removed deprecated MSC2716 requests.
-* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if
- the content struct doesn't implement `Relatable`.
-* *(crypto)* Changed olm unwedging to ignore newly created sessions if they
- haven't been used successfully in either direction.
-* *(federation)* Added utilities for generating, parsing, validating and
- authorizing PDUs.
- * Note: the new PDU code depends on `GOEXPERIMENT=jsonv2`
-* *(event)* Added `is_animated` flag from [MSC4230] to file info.
-* *(event)* Added types for [MSC4332]: In-room bot commands.
-* *(event)* Added missing poll end event type for [MSC3381].
-* *(appservice)* Fixed URLs not being escaped properly when using unix socket
- for homeserver connections.
-* *(format)* Added more helpers for forming markdown links.
-* *(event,bridgev2)* Added support for Beeper's disappearing message state event.
-* *(bridgev2)* Redesigned group creation interface and added support in commands
- and provisioning API.
-* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to
- get an old event. The method is best effort only, as some configurations don't
- allow fetching old events.
-* *(bridgev2)* Added shared logic for provisioning that can be reused by the
- API, commands and other sources.
-* *(bridgev2)* Fixed mentions and URL previews not being copied over when
- caption and media are merged.
-* *(bridgev2)* Removed config option to change provisioning API prefix, which
- had already broken in the previous release.
-
-[@fmseals]: https://github.com/fmseals
-[#393]: https://github.com/mautrix/go/pull/393
-[#407]: https://github.com/mautrix/go/pull/407
-[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381
-[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230
-[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323
-[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332
-
-## v0.25.0 (2025-08-16)
-
-* Bumped minimum Go version to 1.24.
-* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux
- with standard library ServeMux.
-* *(client,bridgev2)* Added support for creator power in room v12.
-* *(client)* Added option to not set `User-Agent` header for improved Wasm
- compatibility.
-* *(bridgev2)* Added support for following tombstones.
-* *(bridgev2)* Added interface for getting arbitrary state event from Matrix.
-* *(bridgev2)* Added batching to disappearing message queue to ensure it doesn't
- use too many resources even if there are a large number of messages.
-* *(bridgev2/commands)* Added support for canceling QR login with `cancel`
- command.
-* *(client)* Added option to override HTTP client used for .well-known
- resolution.
-* *(crypto/backup)* Added method for encrypting key backup session without
- private keys.
-* *(event->id)* Moved room version type and constants to id package.
-* *(bridgev2)* Bots in DM portals will now be added to the functional members
- state event to hide them from the room name calculation.
-* *(bridgev2)* Changed message delete handling to ignore "delete for me" events
- if there are multiple Matrix users in the room.
-* *(format/htmlparser)* Changed text processing to collapse multiple spaces into
- one when outside `pre`/`code` tags.
-* *(format/htmlparser)* Removed link suffix in plaintext output when link text
- is only missing protocol part of href.
- * e.g. `example.com` will turn into
- `example.com` rather than `example.com (https://example.com)`
-* *(appservice)* Switched appservice websockets from gorilla/websocket to
- coder/websocket.
-* *(bridgev2/matrix)* Fixed encryption key sharing not ignoring ghosts properly.
-* *(crypto/attachments)* Fixed hash check when decrypting file streams.
-* *(crypto)* Removed unnecessary `AlreadyShared` error in `ShareGroupSession`.
- The function will now act as if it was successful instead.
-
## v0.24.2 (2025-07-16)
* *(bridgev2)* Added support for return values from portal event handlers. Note
@@ -437,7 +203,6 @@
[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156
[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190
[#288]: https://github.com/mautrix/go/pull/288
-[@onestacked]: https://github.com/onestacked
## v0.22.0 (2024-11-16)
diff --git a/README.md b/README.md
index b1a2edf8..ac41ca78 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,8 @@
# mautrix-go
[](https://pkg.go.dev/maunium.net/go/mautrix)
-A Golang Matrix framework. Used by [gomuks](https://gomuks.app),
-[go-neb](https://github.com/matrix-org/go-neb),
-[mautrix-whatsapp](https://github.com/mautrix/whatsapp)
+A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks),
+[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp)
and others.
Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net)
@@ -14,10 +13,9 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or
In addition to the basic client API features the original project has, this framework also has:
* Appservice support (Intent API like mautrix-python, room state storage, etc)
-* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc)
+* End-to-end encryption support (incl. interactive SAS verification)
* High-level module for building puppeting bridges
-* Partial federation module (making requests, PDU processing and event authorization)
-* A media proxy server which can be used to expose anything as a Matrix media repo
+* High-level module for building chat clients
* Wrapper functions for the Synapse admin API
* Structs for parsing event content
* Helpers for parsing and generating Matrix HTML
diff --git a/appservice/appservice.go b/appservice/appservice.go
index d7037ef6..518e1073 100644
--- a/appservice/appservice.go
+++ b/appservice/appservice.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -19,7 +19,8 @@ import (
"syscall"
"time"
- "github.com/coder/websocket"
+ "github.com/gorilla/mux"
+ "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
"gopkg.in/yaml.v3"
@@ -42,7 +43,7 @@ func Create() *AppService {
intents: make(map[id.UserID]*IntentAPI),
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
StateStore: mautrix.NewMemoryStateStore().(StateStore),
- Router: http.NewServeMux(),
+ Router: mux.NewRouter(),
UserAgent: mautrix.DefaultUserAgent,
txnIDC: NewTransactionIDCache(128),
Live: true,
@@ -60,12 +61,12 @@ func Create() *AppService {
DefaultHTTPRetries: 4,
}
- as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
- as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
- as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
- as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
- as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
- as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)
+ as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
+ as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
+ as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
+ as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
+ as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
+ as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
return as
}
@@ -113,13 +114,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil)
// QueryHandler handles room alias and user ID queries from the homeserver.
type QueryHandler interface {
- QueryAlias(alias id.RoomAlias) bool
+ QueryAlias(alias string) bool
QueryUser(userID id.UserID) bool
}
type QueryHandlerStub struct{}
-func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool {
+func (qh *QueryHandlerStub) QueryAlias(alias string) bool {
return false
}
@@ -127,7 +128,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool {
return false
}
-type WebsocketHandler func(WebsocketCommand) (ok bool, data any)
+type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
type StateStore interface {
mautrix.StateStore
@@ -159,7 +160,7 @@ type AppService struct {
QueryHandler QueryHandler
StateStore StateStore
- Router *http.ServeMux
+ Router *mux.Router
UserAgent string
server *http.Server
HTTPClient *http.Client
@@ -178,6 +179,7 @@ type AppService struct {
intentsLock sync.RWMutex
ws *websocket.Conn
+ wsWriteLock sync.Mutex
StopWebsocket func(error)
websocketHandlers map[string]WebsocketHandler
websocketHandlersLock sync.RWMutex
@@ -334,7 +336,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error {
} else if as.hsURLForClient.Scheme == "" {
as.hsURLForClient.Scheme = "https"
}
- as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath()
+ as.hsURLForClient.RawPath = parsedURL.EscapedPath()
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar}
@@ -360,7 +362,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
AccessToken: as.Registration.AppToken,
UserAgent: as.UserAgent,
StateStore: as.StateStore,
- Log: as.Log.With().Stringer("as_user_id", userID).Logger(),
+ Log: as.Log.With().Str("as_user_id", userID.String()).Logger(),
Client: as.HTTPClient,
DefaultHTTPRetries: as.DefaultHTTPRetries,
SpecVersions: as.SpecVersions,
diff --git a/appservice/http.go b/appservice/http.go
index 27ce6288..1ebe6e56 100644
--- a/appservice/http.go
+++ b/appservice/http.go
@@ -17,6 +17,7 @@ import (
"syscall"
"time"
+ "github.com/gorilla/mux"
"github.com/rs/zerolog"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/exstrings"
@@ -94,7 +95,8 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
return
}
- txnID := r.PathValue("txnID")
+ vars := mux.Vars(r)
+ txnID := vars["txnID"]
if len(txnID) == 0 {
mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w)
return
@@ -201,7 +203,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def
}
err := evt.Content.ParseRaw(evt.Type)
if errors.Is(err, event.ErrUnsupportedContentType) {
- log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event")
+ log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event")
} else if err != nil {
log.Warn().Err(err).
Str("event_id", evt.ID.String()).
@@ -238,7 +240,8 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
return
}
- roomAlias := id.RoomAlias(r.PathValue("roomAlias"))
+ vars := mux.Vars(r)
+ roomAlias := vars["roomAlias"]
ok := as.QueryHandler.QueryAlias(roomAlias)
if ok {
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
@@ -253,7 +256,8 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
return
}
- userID := id.UserID(r.PathValue("userID"))
+ vars := mux.Vars(r)
+ userID := id.UserID(vars["userID"])
ok := as.QueryHandler.QueryUser(userID)
if ok {
exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
diff --git a/appservice/intent.go b/appservice/intent.go
index 5d43f190..d6cda137 100644
--- a/appservice/intent.go
+++ b/appservice/intent.go
@@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
}
func (intent *IntentAPI) Register(ctx context.Context) error {
- _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{
+ _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{
Username: intent.Localpart,
Type: mautrix.AuthTypeAppservice,
InhibitLogin: true,
@@ -86,7 +86,6 @@ func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error {
type EnsureJoinedParams struct {
IgnoreCache bool
BotOverride *mautrix.Client
- Via []string
}
func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error {
@@ -100,17 +99,11 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
return nil
}
- err := intent.EnsureRegistered(ctx)
- if err != nil {
+ if err := intent.EnsureRegistered(ctx); err != nil {
return fmt.Errorf("failed to ensure joined: %w", err)
}
- var resp *mautrix.RespJoinRoom
- if len(params.Via) > 0 {
- resp, err = intent.JoinRoom(ctx, roomID.String(), &mautrix.ReqJoinRoom{Via: params.Via})
- } else {
- resp, err = intent.JoinRoomByID(ctx, roomID)
- }
+ resp, err := intent.JoinRoomByID(ctx, roomID)
if err != nil {
bot := intent.bot
if params.BotOverride != nil {
@@ -214,31 +207,23 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any {
}
}
-func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
+func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
- return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...)
+ return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON)
}
-func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
- if err := intent.EnsureJoined(ctx, roomID); err != nil {
- return nil, err
- }
- if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
- return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
- }
- contentJSON = intent.AddDoublePuppetValue(contentJSON)
- return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...)
-}
-
-// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead
func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
- return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
+ if err := intent.EnsureJoined(ctx, roomID); err != nil {
+ return nil, err
+ }
+ contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
+ return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
}
-func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
+func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
if eventType != event.StateMember || stateKey != string(intent.UserID) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
@@ -247,12 +232,15 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
- return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...)
+ return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON)
}
-// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead
func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
- return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
+ if err := intent.EnsureJoined(ctx, roomID); err != nil {
+ return nil, err
+ }
+ contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
+ return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts)
}
func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
@@ -311,7 +299,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...)
- return &mautrix.RespJoinRoom{RoomID: roomID}, err
+ return &mautrix.RespJoinRoom{}, err
}
return intent.Client.JoinRoomByID(ctx, roomID)
}
@@ -380,24 +368,6 @@ func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id
return member
}
-func (intent *IntentAPI) FillPowerLevelCreateEvent(ctx context.Context, roomID id.RoomID, pl *event.PowerLevelsEventContent) error {
- if pl.CreateEvent != nil {
- return nil
- }
- var err error
- pl.CreateEvent, err = intent.StateStore.GetCreate(ctx, roomID)
- if err != nil {
- return fmt.Errorf("failed to get create event from cache: %w", err)
- } else if pl.CreateEvent != nil {
- return nil
- }
- pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "")
- if err != nil {
- return fmt.Errorf("failed to get create event from server: %w", err)
- }
- return nil
-}
-
func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID)
if err != nil {
@@ -407,12 +377,6 @@ func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl
if pl == nil {
pl = &event.PowerLevelsEventContent{}
err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl)
- if err != nil {
- return
- }
- }
- if pl.CreateEvent == nil {
- pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "")
}
return
}
@@ -427,7 +391,8 @@ func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, us
return nil, err
}
- if pl.EnsureUserLevelAs(intent.UserID, userID, level) {
+ if pl.GetUserLevel(userID) != level {
+ pl.SetUserLevel(userID, level)
return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl)
}
return nil, nil
@@ -516,7 +481,7 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU
// No need to update
return nil
}
- if !avatarURL.IsEmpty() && !intent.SpecVersions.Supports(mautrix.BeeperFeatureHungry) {
+ if !avatarURL.IsEmpty() {
// Some homeservers require the avatar to be downloaded before setting it
resp, _ := intent.Download(ctx, avatarURL)
if resp != nil {
diff --git a/appservice/websocket.go b/appservice/websocket.go
index ef65e65a..3d5bd232 100644
--- a/appservice/websocket.go
+++ b/appservice/websocket.go
@@ -11,15 +11,15 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
"net/http"
"net/url"
- "path"
+ "path/filepath"
"strings"
"sync"
"sync/atomic"
+ "time"
- "github.com/coder/websocket"
+ "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -28,9 +28,11 @@ import (
)
type WebsocketRequest struct {
- ReqID int `json:"id,omitempty"`
- Command string `json:"command"`
- Data any `json:"data"`
+ ReqID int `json:"id,omitempty"`
+ Command string `json:"command"`
+ Data interface{} `json:"data"`
+
+ Deadline time.Duration `json:"-"`
}
type WebsocketCommand struct {
@@ -41,7 +43,7 @@ type WebsocketCommand struct {
Ctx context.Context `json:"-"`
}
-func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
+func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest {
if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" {
return nil
}
@@ -56,7 +58,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
var prefixMessage string
for unwrappedErr != nil {
errorData, jsonErr = json.Marshal(unwrappedErr)
- if len(errorData) > 2 && jsonErr == nil {
+ if errorData != nil && len(errorData) > 2 && jsonErr == nil {
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
prefixMessage = strings.TrimRight(prefixMessage, ": ")
break
@@ -98,8 +100,8 @@ type WebsocketMessage struct {
}
const (
- WebsocketCloseConnReplaced websocket.StatusCode = 4001
- WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002
+ WebsocketCloseConnReplaced = 4001
+ WebsocketCloseTxnNotAcknowledged = 4002
)
type MeowWebsocketCloseCode string
@@ -133,7 +135,7 @@ func (mwcc MeowWebsocketCloseCode) String() string {
}
type CloseCommand struct {
- Code websocket.StatusCode `json:"-"`
+ Code int `json:"-"`
Command string `json:"command"`
Status MeowWebsocketCloseCode `json:"status"`
}
@@ -143,15 +145,15 @@ func (cc CloseCommand) Error() string {
}
func parseCloseError(err error) error {
- var closeError websocket.CloseError
+ closeError := &websocket.CloseError{}
if !errors.As(err, &closeError) {
return err
}
var closeCommand CloseCommand
closeCommand.Code = closeError.Code
closeCommand.Command = "disconnect"
- if len(closeError.Reason) > 0 {
- jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand)
+ if len(closeError.Text) > 0 {
+ jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
if jsonErr != nil {
return err
}
@@ -159,7 +161,7 @@ func parseCloseError(err error) error {
if len(closeCommand.Status) == 0 {
if closeCommand.Code == WebsocketCloseConnReplaced {
closeCommand.Status = MeowConnectionReplaced
- } else if closeCommand.Code == websocket.StatusServiceRestart {
+ } else if closeCommand.Code == websocket.CloseServiceRestart {
closeCommand.Status = MeowServerShuttingDown
}
}
@@ -170,23 +172,20 @@ func (as *AppService) HasWebsocket() bool {
return as.ws != nil
}
-func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error {
+func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error {
ws := as.ws
if cmd == nil {
return nil
} else if ws == nil {
return ErrWebsocketNotConnected
}
- wr, err := ws.Writer(ctx, websocket.MessageText)
- if err != nil {
- return err
+ as.wsWriteLock.Lock()
+ defer as.wsWriteLock.Unlock()
+ if cmd.Deadline == 0 {
+ cmd.Deadline = 3 * time.Minute
}
- err = json.NewEncoder(wr).Encode(cmd)
- if err != nil {
- _ = wr.Close()
- return err
- }
- return wr.Close()
+ _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline))
+ return ws.WriteJSON(cmd)
}
func (as *AppService) clearWebsocketResponseWaiters() {
@@ -223,12 +222,12 @@ func (er *ErrorResponse) Error() string {
return fmt.Sprintf("%s: %s", er.Code, er.Message)
}
-func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error {
+func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error {
cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1))
respChan := make(chan *WebsocketCommand, 1)
as.addWebsocketResponseWaiter(cmd.ReqID, respChan)
defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan)
- err := as.SendWebsocket(ctx, cmd)
+ err := as.SendWebsocket(cmd)
if err != nil {
return err
}
@@ -257,7 +256,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques
}
}
-func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) {
+func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) {
zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command")
return false, fmt.Errorf("unknown request type")
}
@@ -281,28 +280,14 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg
return true, &WebsocketTransactionResponse{TxnID: msg.TxnID}
}
-func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) {
+func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
defer stopFunc(ErrWebsocketUnknownError)
+ ctx := context.Background()
for {
- msgType, reader, err := ws.Reader(ctx)
- if err != nil {
- as.Log.Debug().Err(err).Msg("Error getting reader from websocket")
- stopFunc(parseCloseError(err))
- return
- } else if msgType != websocket.MessageText {
- as.Log.Debug().Msg("Ignoring non-text message from websocket")
- continue
- }
- data, err := io.ReadAll(reader)
- if err != nil {
- as.Log.Debug().Err(err).Msg("Error reading data from websocket")
- stopFunc(parseCloseError(err))
- return
- }
var msg WebsocketMessage
- err = json.Unmarshal(data, &msg)
+ err := ws.ReadJSON(&msg)
if err != nil {
- as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket")
+ as.Log.Debug().Err(err).Msg("Error reading from websocket")
stopFunc(parseCloseError(err))
return
}
@@ -313,11 +298,11 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error)
with = with.Str("transaction_id", msg.TxnID)
}
log := with.Logger()
- ctx := log.WithContext(ctx)
+ ctx = log.WithContext(ctx)
if msg.Command == "" || msg.Command == "transaction" {
ok, resp := as.WebsocketTransactionHandler(ctx, msg)
go func() {
- err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp))
+ err := as.SendWebsocket(msg.MakeResponse(ok, resp))
if err != nil {
log.Warn().Err(err).Msg("Failed to send response to websocket transaction")
} else {
@@ -349,7 +334,7 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error)
}
go func() {
okResp, data := handler(msg.WebsocketCommand)
- err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data))
+ err := as.SendWebsocket(msg.MakeResponse(okResp, data))
if err != nil {
log.Error().Err(err).Msg("Failed to send response to websocket command")
} else if okResp {
@@ -362,7 +347,7 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error)
}
}
-func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error {
+func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
var parsed *url.URL
if baseURL != "" {
var err error
@@ -374,21 +359,18 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
copiedURL := *as.hsURLForClient
parsed = &copiedURL
}
- parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
+ parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
if parsed.Scheme == "http" {
parsed.Scheme = "ws"
} else if parsed.Scheme == "https" {
parsed.Scheme = "wss"
}
- ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{
- HTTPClient: as.HTTPClient,
- HTTPHeader: http.Header{
- "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
- "User-Agent": []string{as.BotClient().UserAgent},
+ ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
+ "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
+ "User-Agent": []string{as.BotClient().UserAgent},
- "X-Mautrix-Process-ID": []string{as.ProcessID},
- "X-Mautrix-Websocket-Version": []string{"3"},
- },
+ "X-Mautrix-Process-ID": []string{as.ProcessID},
+ "X-Mautrix-Websocket-Version": []string{"3"},
})
if resp != nil && resp.StatusCode >= 400 {
var errResp mautrix.RespError
@@ -419,13 +401,12 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
}
})
}
- ws.SetReadLimit(50 * 1024 * 1024)
as.ws = ws
as.StopWebsocket = stopFunc
as.PrepareWebsocket()
as.Log.Debug().Msg("Appservice transaction websocket opened")
- go as.consumeWebsocket(ctx, stopFunc, ws)
+ go as.consumeWebsocket(stopFunc, ws)
var onConnectDone atomic.Bool
if onConnect != nil {
@@ -447,7 +428,12 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
as.ws = nil
}
- err = ws.Close(websocket.StatusGoingAway, "")
+ _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second))
+ err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
+ if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
+ as.Log.Warn().Err(err).Msg("Error writing close message to websocket")
+ }
+ err = ws.Close()
if err != nil {
as.Log.Warn().Err(err).Msg("Error closing websocket")
}
diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go
index 226adc90..a4ce033e 100644
--- a/bridgev2/bridge.go
+++ b/bridgev2/bridge.go
@@ -9,14 +9,11 @@ package bridgev2
import (
"context"
"fmt"
- "os"
"sync"
- "sync/atomic"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
- "go.mau.fi/util/exhttp"
"go.mau.fi/util/exsync"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
@@ -54,7 +51,6 @@ type Bridge struct {
Background bool
ExternallyManagedDB bool
- stopping atomic.Bool
wakeupBackfillQueue chan struct{}
stopBackfillQueue *exsync.Event
@@ -124,13 +120,12 @@ func (br *Bridge) Start(ctx context.Context) error {
if err != nil {
return err
}
- go br.PostStart(ctx)
+ br.PostStart(ctx)
return nil
}
func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error {
br.Background = true
- br.stopping.Store(false)
err := br.StartConnectors(ctx)
if err != nil {
return err
@@ -166,7 +161,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
case <-time.After(20 * time.Second):
case <-ctx.Done():
}
- br.stopping.Store(true)
return nil
} else {
br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode")
@@ -176,7 +170,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
func (br *Bridge) StartConnectors(ctx context.Context) error {
br.Log.Info().Msg("Starting bridge")
- br.stopping.Store(false)
if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil {
br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background())
br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx)
@@ -189,11 +182,7 @@ func (br *Bridge) StartConnectors(ctx context.Context) error {
}
}
if !br.Background {
- var postMigrate func()
- br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx)
- if postMigrate != nil {
- defer postMigrate()
- }
+ br.didSplitPortals = br.MigrateToSplitPortals(ctx)
}
br.Log.Info().Msg("Starting Matrix connector")
err := br.Matrix.Start(ctx)
@@ -282,64 +271,20 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps b
Msg("Resent bridge info to all portals")
}
-func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) {
+func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool {
log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger()
ctx = log.WithContext(ctx)
if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" {
- return false, nil
+ return false
}
affected, err := br.DB.Portal.MigrateToSplitPortals(ctx)
if err != nil {
- log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals")
- os.Exit(31)
- return false, nil
+ log.Err(err).Msg("Failed to migrate portals")
+ return false
}
log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals")
- affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx)
- if err != nil {
- log.Err(err).Msg("Failed to fix parent portals after split portal migration")
- os.Exit(31)
- return false, nil
- }
- log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration")
- withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx)
- if err != nil {
- log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate")
- os.Exit(31)
- return false, nil
- }
- var roomsToDelete []id.RoomID
- log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver")
- for _, portal := range withoutReceiver {
- if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil {
- log.Err(err).
- Str("portal_id", string(portal.ID)).
- Stringer("mxid", portal.MXID).
- Msg("Failed to delete portal database row that failed to migrate")
- } else if portal.MXID != "" {
- log.Debug().
- Str("portal_id", string(portal.ID)).
- Stringer("mxid", portal.MXID).
- Msg("Marked portal room for deletion from homeserver")
- roomsToDelete = append(roomsToDelete, portal.MXID)
- } else {
- log.Debug().
- Str("portal_id", string(portal.ID)).
- Msg("Deleted portal row with no Matrix room")
- }
- }
br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true")
- log.Info().Msg("Finished split portal migration successfully")
- return affected > 0, func() {
- for _, roomID := range roomsToDelete {
- if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil {
- log.Err(err).
- Stringer("mxid", roomID).
- Msg("Failed to delete portal room that failed to migrate")
- }
- }
- log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate")
- }
+ return affected > 0
}
func (br *Bridge) StartLogins(ctx context.Context) error {
@@ -374,46 +319,6 @@ func (br *Bridge) StartLogins(ctx context.Context) error {
return nil
}
-func (br *Bridge) ResetNetworkConnections() {
- nrn, ok := br.Network.(NetworkResettingNetwork)
- if ok {
- br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections")
- nrn.ResetNetworkConnections()
- return
- }
-
- br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually")
- for _, login := range br.GetAllCachedUserLogins() {
- login.Log.Debug().Msg("Disconnecting and recreating client for network reset")
- ctx := login.Log.WithContext(br.BackgroundCtx)
- login.Client.Disconnect()
- err := login.recreateClient(ctx)
- if err != nil {
- login.Log.Err(err).Msg("Failed to recreate client during network reset")
- login.BridgeState.Send(status.BridgeState{
- StateEvent: status.StateUnknownError,
- Error: "bridgev2-network-reset-fail",
- Info: map[string]any{"go_error": err.Error()},
- })
- } else {
- login.Client.Connect(ctx)
- }
- }
- br.Log.Info().Msg("Finished resetting all user logins")
-}
-
-func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings {
- mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings)
- if ok {
- return mchs.GetHTTPClientSettings()
- }
- return exhttp.SensibleClientSettings
-}
-
-func (br *Bridge) IsStopping() bool {
- return br.stopping.Load()
-}
-
func (br *Bridge) Stop() {
br.stop(false, 0)
}
@@ -424,7 +329,6 @@ func (br *Bridge) StopWithTimeout(timeout time.Duration) {
func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) {
br.Log.Info().Msg("Shutting down bridge")
- br.stopping.Store(true)
br.DisappearLoop.Stop()
br.stopBackfillQueue.Set()
br.Matrix.PreStop()
diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go
index eedae1e8..53282e41 100644
--- a/bridgev2/bridgeconfig/backfill.go
+++ b/bridgev2/bridgeconfig/backfill.go
@@ -34,12 +34,10 @@ type BackfillQueueConfig struct {
MaxBatchesOverride map[string]int `yaml:"max_batches_override"`
}
-func (bqc *BackfillQueueConfig) GetOverride(names ...string) int {
- for _, name := range names {
- override, ok := bqc.MaxBatchesOverride[name]
- if ok {
- return override
- }
+func (bqc *BackfillQueueConfig) GetOverride(name string) int {
+ override, ok := bqc.MaxBatchesOverride[name]
+ if !ok {
+ return bqc.MaxBatches
}
- return bqc.MaxBatches
+ return override
}
diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go
index bd6b9c06..9bdee5fe 100644
--- a/bridgev2/bridgeconfig/config.go
+++ b/bridgev2/bridgeconfig/config.go
@@ -33,8 +33,6 @@ type Config struct {
Encryption EncryptionConfig `yaml:"encryption"`
Logging zeroconfig.Config `yaml:"logging"`
- EnvConfigPrefix string `yaml:"env_config_prefix"`
-
ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"`
}
@@ -62,40 +60,36 @@ type CleanupOnLogouts struct {
}
type BridgeConfig struct {
- CommandPrefix string `yaml:"command_prefix"`
- PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
- PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
- AsyncEvents bool `yaml:"async_events"`
- SplitPortals bool `yaml:"split_portals"`
- ResendBridgeInfo bool `yaml:"resend_bridge_info"`
- NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
- BridgeStatusNotices string `yaml:"bridge_status_notices"`
- UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
- UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"`
- BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
- BridgeNotices bool `yaml:"bridge_notices"`
- TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
- OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
- MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
- DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
- CrossRoomReplies bool `yaml:"cross_room_replies"`
- OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
- RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"`
- KickMatrixUsers bool `yaml:"kick_matrix_users"`
- CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
- Relay RelayConfig `yaml:"relay"`
- Permissions PermissionConfig `yaml:"permissions"`
- Backfill BackfillConfig `yaml:"backfill"`
+ CommandPrefix string `yaml:"command_prefix"`
+ PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
+ PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
+ AsyncEvents bool `yaml:"async_events"`
+ SplitPortals bool `yaml:"split_portals"`
+ ResendBridgeInfo bool `yaml:"resend_bridge_info"`
+ NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
+ BridgeStatusNotices string `yaml:"bridge_status_notices"`
+ UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
+ BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
+ BridgeNotices bool `yaml:"bridge_notices"`
+ TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
+ OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
+ MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
+ DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
+ CrossRoomReplies bool `yaml:"cross_room_replies"`
+ OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
+ CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
+ Relay RelayConfig `yaml:"relay"`
+ Permissions PermissionConfig `yaml:"permissions"`
+ Backfill BackfillConfig `yaml:"backfill"`
}
type MatrixConfig struct {
- MessageStatusEvents bool `yaml:"message_status_events"`
- DeliveryReceipts bool `yaml:"delivery_receipts"`
- MessageErrorNotices bool `yaml:"message_error_notices"`
- SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
- FederateRooms bool `yaml:"federate_rooms"`
- UploadFileThreshold int64 `yaml:"upload_file_threshold"`
- GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"`
+ MessageStatusEvents bool `yaml:"message_status_events"`
+ DeliveryReceipts bool `yaml:"delivery_receipts"`
+ MessageErrorNotices bool `yaml:"message_error_notices"`
+ SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
+ FederateRooms bool `yaml:"federate_rooms"`
+ UploadFileThreshold int64 `yaml:"upload_file_threshold"`
}
type AnalyticsConfig struct {
@@ -105,6 +99,7 @@ type AnalyticsConfig struct {
}
type ProvisioningConfig struct {
+ Prefix string `yaml:"prefix"`
SharedSecret string `yaml:"shared_secret"`
DebugEndpoints bool `yaml:"debug_endpoints"`
EnableSessionTransfers bool `yaml:"enable_session_transfers"`
@@ -117,12 +112,10 @@ type DirectMediaConfig struct {
}
type PublicMediaConfig struct {
- Enabled bool `yaml:"enabled"`
- SigningKey string `yaml:"signing_key"`
- Expiry int `yaml:"expiry"`
- HashLength int `yaml:"hash_length"`
- PathPrefix string `yaml:"path_prefix"`
- UseDatabase bool `yaml:"use_database"`
+ Enabled bool `yaml:"enabled"`
+ SigningKey string `yaml:"signing_key"`
+ HashLength int `yaml:"hash_length"`
+ Expiry int `yaml:"expiry"`
}
type DoublePuppetConfig struct {
diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go
index 934613ca..1ef7e18f 100644
--- a/bridgev2/bridgeconfig/encryption.go
+++ b/bridgev2/bridgeconfig/encryption.go
@@ -16,8 +16,6 @@ type EncryptionConfig struct {
Require bool `yaml:"require"`
Appservice bool `yaml:"appservice"`
MSC4190 bool `yaml:"msc4190"`
- MSC4392 bool `yaml:"msc4392"`
- SelfSign bool `yaml:"self_sign"`
PlaintextMentions bool `yaml:"plaintext_mentions"`
diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go
index 954a37c3..fb2a86d6 100644
--- a/bridgev2/bridgeconfig/legacymigrate.go
+++ b/bridgev2/bridgeconfig/legacymigrate.go
@@ -133,7 +133,9 @@ func doMigrateLegacy(helper up.Helper, python bool) {
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"})
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"})
+ CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
+ CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"})
diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go
index 9efe068e..610051e0 100644
--- a/bridgev2/bridgeconfig/permissions.go
+++ b/bridgev2/bridgeconfig/permissions.go
@@ -24,7 +24,6 @@ type Permissions struct {
DoublePuppet bool `yaml:"double_puppet"`
Admin bool `yaml:"admin"`
ManageRelay bool `yaml:"manage_relay"`
- MaxLogins int `yaml:"max_logins"`
}
type PermissionConfig map[string]*Permissions
@@ -41,7 +40,10 @@ func (pc PermissionConfig) IsConfigured() bool {
_, hasExampleDomain := pc["example.com"]
_, hasExampleUser := pc["@admin:example.com"]
exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain)
- return len(pc) > exampleLen
+ if len(pc) <= exampleLen {
+ return false
+ }
+ return true
}
func (pc PermissionConfig) Get(userID id.UserID) Permissions {
diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go
index 92515ea0..b69a1fdb 100644
--- a/bridgev2/bridgeconfig/upgrade.go
+++ b/bridgev2/bridgeconfig/upgrade.go
@@ -33,7 +33,6 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key")
helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices")
helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect")
- helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects")
helper.Copy(up.Bool, "bridge", "bridge_matrix_leave")
helper.Copy(up.Bool, "bridge", "bridge_notices")
helper.Copy(up.Bool, "bridge", "tag_only_on_create")
@@ -41,8 +40,6 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "bridge", "mute_only_on_create")
helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages")
helper.Copy(up.Bool, "bridge", "cross_room_replies")
- helper.Copy(up.Bool, "bridge", "revert_failed_state_changes")
- helper.Copy(up.Bool, "bridge", "kick_matrix_users")
helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled")
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private")
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed")
@@ -101,12 +98,12 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "matrix", "sync_direct_chat_list")
helper.Copy(up.Bool, "matrix", "federate_rooms")
helper.Copy(up.Int, "matrix", "upload_file_threshold")
- helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info")
helper.Copy(up.Str|up.Null, "analytics", "token")
helper.Copy(up.Str|up.Null, "analytics", "url")
helper.Copy(up.Str|up.Null, "analytics", "user_id")
+ helper.Copy(up.Str, "provisioning", "prefix")
if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" {
sharedSecret := random.String(64)
helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret")
@@ -136,8 +133,6 @@ func doUpgrade(helper up.Helper) {
}
helper.Copy(up.Int, "public_media", "expiry")
helper.Copy(up.Int, "public_media", "hash_length")
- helper.Copy(up.Str|up.Null, "public_media", "path_prefix")
- helper.Copy(up.Bool, "public_media", "use_database")
helper.Copy(up.Bool, "backfill", "enabled")
helper.Copy(up.Int, "backfill", "max_initial_messages")
@@ -163,8 +158,6 @@ func doUpgrade(helper up.Helper) {
} else {
helper.Copy(up.Bool, "encryption", "msc4190")
}
- helper.Copy(up.Bool, "encryption", "msc4392")
- helper.Copy(up.Bool, "encryption", "self_sign")
helper.Copy(up.Bool, "encryption", "allow_key_sharing")
if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" {
helper.Set(up.Str, random.String(64), "encryption", "pickle_key")
@@ -187,8 +180,6 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Int, "encryption", "rotation", "messages")
helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation")
- helper.Copy(up.Str|up.Null, "env_config_prefix")
-
helper.Copy(up.Map, "logging")
}
@@ -216,7 +207,6 @@ var SpacedBlocks = [][]string{
{"backfill"},
{"double_puppet"},
{"encryption"},
- {"env_config_prefix"},
{"logging"},
}
diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go
index 96d9fd5c..f31d4e92 100644
--- a/bridgev2/bridgestate.go
+++ b/bridgev2/bridgestate.go
@@ -15,15 +15,12 @@ import (
"time"
"github.com/rs/zerolog"
- "go.mau.fi/util/exfmt"
"maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
)
-var CatchBridgeStateQueuePanics = true
-
type BridgeStateQueue struct {
prevUnsent *status.BridgeState
prevSent *status.BridgeState
@@ -32,13 +29,8 @@ type BridgeStateQueue struct {
bridge *Bridge
login *UserLogin
- firstTransientDisconnect time.Time
- cancelScheduledNotice atomic.Pointer[context.CancelFunc]
-
stopChan chan struct{}
stopReconnect atomic.Pointer[context.CancelFunc]
-
- unknownErrorReconnects int
}
func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) {
@@ -82,63 +74,31 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() {
if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil {
(*cancelFn)()
}
- if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil {
- (*cancelFn)()
- }
}
func (bsq *BridgeStateQueue) loop() {
- if CatchBridgeStateQueuePanics {
- defer func() {
- err := recover()
- if err != nil {
- bsq.login.Log.Error().
- Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
- Any(zerolog.ErrorFieldName, err).
- Msg("Panic in bridge state loop")
- }
- }()
- }
+ defer func() {
+ err := recover()
+ if err != nil {
+ bsq.login.Log.Error().
+ Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
+ Any(zerolog.ErrorFieldName, err).
+ Msg("Panic in bridge state loop")
+ }
+ }()
for state := range bsq.ch {
bsq.immediateSendBridgeState(state)
}
}
-func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) {
- log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger()
- ctx := log.WithContext(bsq.bridge.BackgroundCtx)
- if !bsq.waitForTransientDisconnectReconnect(ctx) {
- return
- }
- prevUnsent := bsq.GetPrevUnsent()
- prev := bsq.GetPrev()
- if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent ||
- prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect {
- log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice")
- return
- }
- log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice")
- bsq.sendNotice(ctx, triggeredBy, true)
-}
-
-func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) {
+func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) {
noticeConfig := bsq.bridge.Config.BridgeStatusNotices
isError := state.StateEvent == status.StateBadCredentials ||
state.StateEvent == status.StateUnknownError ||
- state.UserAction == status.UserActionOpenNative ||
- (isDelayed && state.StateEvent == status.StateTransientDisconnect)
+ state.UserAction == status.UserActionOpenNative
sendNotice := noticeConfig == "all" || (noticeConfig == "errors" &&
(isError || (bsq.errorSent && state.StateEvent == status.StateConnected)))
- if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError {
- bsq.firstTransientDisconnect = time.Time{}
- }
if !sendNotice {
- if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect {
- if bsq.firstTransientDisconnect.IsZero() {
- bsq.firstTransientDisconnect = time.Now()
- }
- go bsq.scheduleNotice(state)
- }
return
}
managementRoom, err := bsq.login.User.GetManagementRoom(ctx)
@@ -154,9 +114,6 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge
if state.Error != "" {
message += fmt.Sprintf(" (`%s`)", state.Error)
}
- if isDelayed {
- message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay))
- }
if state.Message != "" {
message += fmt.Sprintf(": %s", state.Message)
}
@@ -194,14 +151,8 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat
} else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError {
log.Debug().Msg("Not reconnecting as the previous state was not an unknown error")
return
- } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects {
- log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached")
- return
}
- bsq.unknownErrorReconnects++
- log.Info().
- Int("reconnect_num", bsq.unknownErrorReconnects).
- Msg("Disconnecting and reconnecting login due to unknown error")
+ log.Info().Msg("Disconnecting and reconnecting login due to unknown error")
bsq.login.Disconnect()
log.Debug().Msg("Disconnection finished, recreating client and reconnecting")
err := bsq.login.recreateClient(ctx)
@@ -220,30 +171,14 @@ func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) b
return false
}
reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2))
- return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect)
-}
-
-const TransientDisconnectNoticeDelay = 3 * time.Minute
-
-func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool {
- timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay))
- zerolog.Ctx(ctx).Trace().
- Stringer("duration", timeUntilSchedule).
- Msg("Waiting before sending notice about transient disconnect")
- return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice)
-}
-
-func (bsq *BridgeStateQueue) waitForReconnect(
- ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc],
-) bool {
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
- if oldCancel := ptr.Swap(&cancel); oldCancel != nil {
+ if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil {
(*oldCancel)()
}
select {
case <-time.After(reconnectIn):
- return ptr.CompareAndSwap(&cancel, nil)
+ return bsq.stopReconnect.CompareAndSwap(&cancel, nil)
case <-cancelCtx.Done():
return false
case <-bsq.stopChan:
@@ -263,7 +198,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState)
}
ctx := bsq.login.Log.WithContext(context.Background())
- bsq.sendNotice(ctx, state, false)
+ bsq.sendNotice(ctx, state)
retryIn := 2
for {
diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go
index 1cae98fe..4c93dbd4 100644
--- a/bridgev2/commands/debug.go
+++ b/bridgev2/commands/debug.go
@@ -7,13 +7,10 @@
package commands
import (
- "encoding/json"
"strings"
- "time"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/event"
)
var CommandRegisterPush = &FullHandler{
@@ -62,64 +59,3 @@ var CommandRegisterPush = &FullHandler{
RequiresLogin: true,
NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI],
}
-
-var CommandSendAccountData = &FullHandler{
- Func: func(ce *Event) {
- if len(ce.Args) < 2 {
- ce.Reply("Usage: `$cmdprefix debug-account-data ")
- return
- }
- var content event.Content
- evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType}
- ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0]))
- err := json.Unmarshal([]byte(ce.RawArgs), &content)
- if err != nil {
- ce.Reply("Failed to parse JSON: %v", err)
- return
- }
- err = content.ParseRaw(evtType)
- if err != nil {
- ce.Reply("Failed to deserialize content: %v", err)
- return
- }
- res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{
- Sender: ce.User.MXID,
- Type: evtType,
- Timestamp: time.Now().UnixMilli(),
- RoomID: ce.RoomID,
- Content: content,
- })
- ce.Reply("Result: %+v", res)
- },
- Name: "debug-account-data",
- Help: HelpMeta{
- Section: HelpSectionAdmin,
- Description: "Send a room account data event to the bridge",
- Args: "<_type_> <_content_>",
- },
- RequiresAdmin: true,
- RequiresPortal: true,
- RequiresLogin: true,
-}
-
-var CommandResetNetwork = &FullHandler{
- Func: func(ce *Event) {
- if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") {
- nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork)
- if ok {
- nrn.ResetHTTPTransport()
- } else {
- ce.Reply("Network connector does not support resetting HTTP transport")
- }
- }
- ce.Bridge.ResetNetworkConnections()
- ce.React("✅️")
- },
- Name: "debug-reset-network",
- Help: HelpMeta{
- Section: HelpSectionAdmin,
- Description: "Reset network connections to the remote network",
- Args: "[--reset-transport]",
- },
- RequiresAdmin: true,
-}
diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go
index 96d62d3e..3544998c 100644
--- a/bridgev2/commands/login.go
+++ b/bridgev2/commands/login.go
@@ -70,15 +70,6 @@ func fnLogin(ce *Event) {
}
ce.Args = ce.Args[1:]
}
- if reauth == nil && ce.User.HasTooManyLogins() {
- ce.Reply(
- "You have reached the maximum number of logins (%d). "+
- "Please logout from an existing login before creating a new one. "+
- "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.",
- ce.User.Permissions.MaxLogins,
- )
- return
- }
flows := ce.Bridge.Network.GetLoginFlows()
var chosenFlowID string
if len(ce.Args) > 0 {
@@ -121,7 +112,6 @@ func fnLogin(ce *Event) {
ce.Reply("Failed to start login: %v", err)
return
}
- ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process")
nextStep = checkLoginCommandDirectParams(ce, login, nextStep)
if nextStep != nil {
@@ -200,14 +190,11 @@ type userInputLoginCommandState struct {
func (uilcs *userInputLoginCommandState) promptNext(ce *Event) {
field := uilcs.RemainingFields[0]
- parts := []string{fmt.Sprintf("Please enter your %s", field.Name)}
if field.Description != "" {
- parts = append(parts, field.Description)
+ ce.Reply("Please enter your %s\n%s", field.Name, field.Description)
+ } else {
+ ce.Reply("Please enter your %s", field.Name)
}
- if len(field.Options) > 0 {
- parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `")))
- }
- ce.Reply(strings.Join(parts, "\n"))
StoreCommandState(ce.User, &CommandState{
Next: MinimalCommandHandlerFunc(uilcs.submitNext),
Action: "Login",
@@ -252,19 +239,14 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return fmt.Errorf("failed to upload image: %w", err)
}
content := &event.MessageEventContent{
- MsgType: event.MsgImage,
- FileName: "qr.png",
- URL: qrMXC,
- File: qrFile,
+ MsgType: event.MsgImage,
+ FileName: "qr.png",
+ URL: qrMXC,
+ File: qrFile,
+
Body: qr,
Format: event.FormatHTML,
FormattedBody: fmt.Sprintf("%s
", html.EscapeString(qr)),
- Info: &event.FileInfo{
- MimeType: "image/png",
- Width: qrSizePx,
- Height: qrSizePx,
- Size: len(qrData),
- },
}
if *prevEventID != "" {
content.SetEdit(*prevEventID)
@@ -279,36 +261,6 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return nil
}
-func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
- for _, att := range atts {
- if att.FileName == "" {
- return fmt.Errorf("missing attachment filename")
- }
- mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
- if err != nil {
- return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
- }
- content := &event.MessageEventContent{
- MsgType: att.Type,
- FileName: att.FileName,
- URL: mxc,
- File: file,
- Info: &event.FileInfo{
- MimeType: att.Info.MimeType,
- Width: att.Info.Width,
- Height: att.Info.Height,
- Size: att.Info.Size,
- },
- Body: att.FileName,
- }
- _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
- if err != nil {
- return nil
- }
- }
- return nil
-}
-
type contextKey int
const (
@@ -321,13 +273,6 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
prevEvent = new(id.EventID)
ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent)
}
- cancelCtx, cancelFunc := context.WithCancel(ce.Ctx)
- defer cancelFunc()
- StoreCommandState(ce.User, &CommandState{
- Action: "Login",
- Cancel: cancelFunc,
- })
- defer StoreCommandState(ce.User, nil)
switch step.DisplayAndWaitParams.Type {
case bridgev2.LoginDisplayTypeQR:
err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent)
@@ -347,7 +292,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait,
login.Cancel()
return
}
- nextStep, err := login.Wait(cancelCtx)
+ nextStep, err := login.Wait(ce.Ctx)
// Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited)
if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) {
_, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{
@@ -500,7 +445,6 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
}
func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
- ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
if step.Instructions != "" {
ce.Reply(step.Instructions)
}
@@ -515,10 +459,6 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte
Override: override,
}).prompt(ce)
case bridgev2.LoginStepTypeUserInput:
- err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
- if err != nil {
- ce.Reply("Failed to send attachments: %v", err)
- }
(&userInputLoginCommandState{
Login: login.(bridgev2.LoginProcessUserInput),
RemainingFields: step.UserInputParams.Fields,
diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go
index 391c3685..c28e3a32 100644
--- a/bridgev2/commands/processor.go
+++ b/bridgev2/commands/processor.go
@@ -41,11 +41,10 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor {
}
proc.AddHandlers(
CommandHelp, CommandCancel,
- CommandRegisterPush, CommandSendAccountData, CommandResetNetwork,
- CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
+ CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
CommandSetRelay, CommandUnsetRelay,
- CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute,
+ CommandResolveIdentifier, CommandStartChat, CommandSearch,
CommandSudo, CommandDoIn,
)
return proc
diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go
index 94c19739..af756c87 100644
--- a/bridgev2/commands/relay.go
+++ b/bridgev2/commands/relay.go
@@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) {
}
onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly
var relay *bridgev2.UserLogin
- if len(ce.Args) == 0 && ce.Portal.Receiver == "" {
+ if len(ce.Args) == 0 {
relay = ce.User.GetDefaultLogin()
isLoggedIn := relay != nil
if onlySetDefaultRelays {
@@ -73,19 +73,9 @@ func fnSetRelay(ce *Event) {
}
}
} else {
- var targetID networkid.UserLoginID
- if ce.Portal.Receiver != "" {
- targetID = ce.Portal.Receiver
- if len(ce.Args) > 0 && ce.Args[0] != string(targetID) {
- ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID)
- return
- }
- } else {
- targetID = networkid.UserLoginID(ce.Args[0])
- }
- relay = ce.Bridge.GetCachedUserLoginByID(targetID)
+ relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
if relay == nil {
- ce.Reply("User login with ID `%s` not found", targetID)
+ ce.Reply("User login with ID `%s` not found", ce.Args[0])
return
} else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) {
// All good
diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go
index c7b05a6e..719d3dd5 100644
--- a/bridgev2/commands/startchat.go
+++ b/bridgev2/commands/startchat.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,21 +8,13 @@ package commands
import (
"context"
- "errors"
"fmt"
"html"
- "maps"
- "slices"
"strings"
"time"
- "github.com/rs/zerolog"
-
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/bridgev2/provisionutil"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
@@ -38,35 +30,6 @@ var CommandResolveIdentifier = &FullHandler{
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
-var CommandSyncChat = &FullHandler{
- Func: func(ce *Event) {
- login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to find login for sync")
- ce.Reply("Failed to find login: %v", err)
- return
- } else if login == nil {
- ce.Reply("No login found for sync")
- return
- }
- info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get chat info for sync")
- ce.Reply("Failed to get chat info: %v", err)
- return
- }
- ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{})
- ce.React("✅️")
- },
- Name: "sync-portal",
- Help: HelpMeta{
- Section: HelpSectionChats,
- Description: "Sync the current portal room",
- },
- RequiresPortal: true,
- RequiresLogin: true,
-}
-
var CommandStartChat = &FullHandler{
Func: fnResolveIdentifier,
Name: "start-chat",
@@ -80,15 +43,9 @@ var CommandStartChat = &FullHandler{
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
-func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
- var remainingArgs []string
- if len(ce.Args) > 1 {
- remainingArgs = ce.Args[1:]
- }
- var login *bridgev2.UserLogin
- if len(ce.Args) > 0 {
- login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
- }
+func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
+ remainingArgs := ce.Args[1:]
+ login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
if login == nil || login.UserMXID != ce.User.MXID {
remainingArgs = ce.Args
login = ce.User.GetDefaultLogin()
@@ -100,13 +57,24 @@ func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*
return login, api, remainingArgs
}
-func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string {
- if resp.MXID != "" {
- return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL())
- } else if resp.Name != "" {
- return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name)
+func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string {
+ var targetName string
+ var targetMXID id.UserID
+ if resp.Ghost != nil {
+ if resp.UserInfo != nil {
+ resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
+ }
+ targetName = resp.Ghost.Name
+ targetMXID = resp.Ghost.Intent.GetMXID()
+ } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
+ targetName = *resp.UserInfo.Name
+ }
+ if targetMXID != "" {
+ return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL())
+ } else if targetName != "" {
+ return fmt.Sprintf("`%s` / %s", resp.UserID, targetName)
} else {
- return fmt.Sprintf("`%s`", resp.ID)
+ return fmt.Sprintf("`%s`", resp.UserID)
}
}
@@ -119,137 +87,65 @@ func fnResolveIdentifier(ce *Event) {
if api == nil {
return
}
- allLogins := ce.User.GetUserLogins()
createChat := ce.Command == "start-chat" || ce.Command == "pm"
identifier := strings.Join(identifierParts, " ")
- resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat)
- for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ {
- resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat)
- }
+ resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat)
if err != nil {
+ ce.Log.Err(err).Msg("Failed to resolve identifier")
ce.Reply("Failed to resolve identifier: %v", err)
return
} else if resp == nil {
ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true)
return
}
- formattedName := formatResolveIdentifierResult(resp)
+ formattedName := formatResolveIdentifierResult(ce.Ctx, resp)
if createChat {
- name := resp.Portal.Name
- if name == "" {
- name = resp.Portal.MXID.String()
+ if resp.Chat == nil {
+ ce.Reply("Interface error: network connector did not return chat for create chat request")
+ return
}
- if !resp.JustCreated {
- ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
+ portal := resp.Chat.Portal
+ if portal == nil {
+ portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get portal")
+ ce.Reply("Failed to get portal: %v", err)
+ return
+ }
+ }
+ if resp.Chat.PortalInfo == nil {
+ resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to get portal info")
+ ce.Reply("Failed to get portal info: %v", err)
+ return
+ }
+ }
+ if portal.MXID != "" {
+ name := portal.Name
+ if name == "" {
+ name = portal.MXID.String()
+ }
+ portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{})
+ ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
} else {
- ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
+ err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo)
+ if err != nil {
+ ce.Log.Err(err).Msg("Failed to create room")
+ ce.Reply("Failed to create room: %v", err)
+ return
+ }
+ name := portal.Name
+ if name == "" {
+ name = portal.MXID.String()
+ }
+ ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
}
} else {
ce.Reply("Found %s", formattedName)
}
}
-var CommandCreateGroup = &FullHandler{
- Func: fnCreateGroup,
- Name: "create-group",
- Aliases: []string{"create"},
- Help: HelpMeta{
- Section: HelpSectionChats,
- Description: "Create a new group chat for the current Matrix room",
- Args: "[_group type_]",
- },
- RequiresLogin: true,
- NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI],
-}
-
-func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) {
- evt, err := provider.GetStateEvent(ctx, roomID, evtType, "")
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation")
- } else if evt != nil {
- content, _ = evt.Content.Parsed.(T)
- }
- return
-}
-
-func fnCreateGroup(ce *Event) {
- ce.Bridge.Matrix.GetCapabilities()
- login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group")
- if api == nil {
- return
- }
- stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState)
- if !ok {
- ce.Reply("Matrix connector doesn't support fetching room state")
- return
- }
- members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID)
- if err != nil {
- ce.Log.Err(err).Msg("Failed to get room members for group creation")
- ce.Reply("Failed to get room members: %v", err)
- return
- }
- caps := ce.Bridge.Network.GetCapabilities()
- params := &bridgev2.GroupCreateParams{
- Username: "",
- Participants: make([]networkid.UserID, 0, len(members)-2),
- Parent: nil, // TODO check space parent event
- Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider),
- Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider),
- Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider),
- Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider),
- RoomID: ce.RoomID,
- }
- for userID, member := range members {
- if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() {
- continue
- }
- if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok {
- params.Participants = append(params.Participants, parsedUserID)
- } else if !ce.Bridge.Config.SplitPortals {
- if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil {
- ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member")
- } else if user != nil {
- // TODO add user logins to participants
- //for _, login := range user.GetUserLogins() {
- // params.Participants = append(params.Participants, login.GetUserID())
- //}
- }
- }
- }
-
- if len(caps.Provisioning.GroupCreation) == 0 {
- ce.Reply("No group creation types defined in network capabilities")
- return
- } else if len(remainingArgs) > 0 {
- params.Type = remainingArgs[0]
- } else if len(caps.Provisioning.GroupCreation) == 1 {
- for params.Type = range caps.Provisioning.GroupCreation {
- // The loop assigns the variable we want
- }
- } else {
- types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `")
- ce.Reply("Please specify type of group to create: `%s`", types)
- return
- }
- resp, err := provisionutil.CreateGroup(ce.Ctx, login, params)
- if err != nil {
- ce.Reply("Failed to create group: %v", err)
- return
- }
- var postfix string
- if len(resp.FailedParticipants) > 0 {
- failedParticipantsStrings := make([]string, len(resp.FailedParticipants))
- i := 0
- for participantID, meta := range resp.FailedParticipants {
- failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason)
- i++
- }
- postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n")
- }
- ce.Reply("Successfully created group `%s`%s", resp.ID, postfix)
-}
-
var CommandSearch = &FullHandler{
Func: fnSearch,
Name: "search",
@@ -267,67 +163,35 @@ func fnSearch(ce *Event) {
ce.Reply("Usage: `$cmdprefix search `")
return
}
- login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
+ _, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
if api == nil {
return
}
- resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " "))
+ results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " "))
if err != nil {
+ ce.Log.Err(err).Msg("Failed to search for users")
ce.Reply("Failed to search for users: %v", err)
return
}
- resultsString := make([]string, len(resp.Results))
- for i, res := range resp.Results {
- formattedName := formatResolveIdentifierResult(res)
+ resultsString := make([]string, len(results))
+ for i, res := range results {
+ formattedName := formatResolveIdentifierResult(ce.Ctx, res)
resultsString[i] = fmt.Sprintf("* %s", formattedName)
- if res.Portal != nil && res.Portal.MXID != "" {
- portalName := res.Portal.Name
- if portalName == "" {
- portalName = res.Portal.MXID.String()
+ if res.Chat != nil {
+ if res.Chat.Portal == nil {
+ res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey)
+ if err != nil {
+ ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal")
+ }
+ }
+ if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" {
+ portalName := res.Chat.Portal.Name
+ if portalName == "" {
+ portalName = res.Chat.Portal.MXID.String()
+ }
+ resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL())
}
- resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL())
}
}
ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n"))
}
-
-var CommandMute = &FullHandler{
- Func: fnMute,
- Name: "mute",
- Aliases: []string{"unmute"},
- Help: HelpMeta{
- Section: HelpSectionChats,
- Description: "Mute or unmute a chat on the remote network",
- Args: "[duration]",
- },
- RequiresPortal: true,
- RequiresLogin: true,
- NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI],
-}
-
-func fnMute(ce *Event) {
- _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats")
- var mutedUntil int64
- if ce.Command == "mute" {
- mutedUntil = -1
- if len(ce.Args) > 0 {
- duration, err := time.ParseDuration(ce.Args[0])
- if err != nil {
- ce.Reply("Invalid duration: %v", err)
- return
- }
- mutedUntil = time.Now().Add(duration).UnixMilli()
- }
- }
- err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{
- MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{
- Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil},
- Portal: ce.Portal,
- },
- })
- if err != nil {
- ce.Reply("Failed to %s chat: %v", ce.Command, err)
- } else {
- ce.React("✅️")
- }
-}
diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go
index 1f920640..224ae626 100644
--- a/bridgev2/database/backfillqueue.go
+++ b/bridgev2/database/backfillqueue.go
@@ -78,11 +78,6 @@ const (
dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11
WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3
`
- markBackfillTaskNotDoneQuery = `
- UPDATE backfill_task
- SET is_done = false
- WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND user_login_id = $4
- `
getNextBackfillQuery = `
SELECT
bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done,
@@ -132,10 +127,6 @@ func (btq *BackfillTaskQuery) Update(ctx context.Context, bq *BackfillTask) erro
return btq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...)
}
-func (btq *BackfillTaskQuery) MarkNotDone(ctx context.Context, portalKey networkid.PortalKey, userLoginID networkid.UserLoginID) error {
- return btq.Exec(ctx, markBackfillTaskNotDoneQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver, userLoginID)
-}
-
func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error) {
return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano())
}
diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go
index 05abddf0..f1789441 100644
--- a/bridgev2/database/database.go
+++ b/bridgev2/database/database.go
@@ -7,7 +7,13 @@
package database
import (
+ "encoding/json"
+ "reflect"
+ "strings"
+
"go.mau.fi/util/dbutil"
+ "golang.org/x/exp/constraints"
+ "golang.org/x/exp/maps"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -28,7 +34,6 @@ type Database struct {
UserPortal *UserPortalQuery
BackfillTask *BackfillTaskQuery
KV *KVQuery
- PublicMedia *PublicMediaQuery
}
type MetaMerger interface {
@@ -136,12 +141,6 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa
BridgeID: bridgeID,
Database: db,
},
- PublicMedia: &PublicMediaQuery{
- BridgeID: bridgeID,
- QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia {
- return &PublicMedia{}
- }),
- },
}
}
@@ -152,3 +151,55 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID)
panic("bridge ID mismatch")
}
}
+
+func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) {
+ if val, found := m[key]; found {
+ floatVal, ok := val.(float64)
+ if ok {
+ return T(floatVal), true
+ }
+ tVal, ok := val.(T)
+ if ok {
+ return tVal, true
+ }
+ }
+ return 0, false
+}
+
+func unmarshalMerge(input []byte, data any, extra *map[string]any) error {
+ err := json.Unmarshal(input, data)
+ if err != nil {
+ return err
+ }
+ err = json.Unmarshal(input, extra)
+ if err != nil {
+ return err
+ }
+ if *extra == nil {
+ *extra = make(map[string]any)
+ }
+ return nil
+}
+
+func marshalMerge(data any, extra map[string]any) ([]byte, error) {
+ if extra == nil {
+ return json.Marshal(data)
+ }
+ merged := make(map[string]any)
+ maps.Copy(merged, extra)
+ dataRef := reflect.ValueOf(data).Elem()
+ dataType := dataRef.Type()
+ for _, field := range reflect.VisibleFields(dataType) {
+ parts := strings.Split(field.Tag.Get("json"), ",")
+ if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" {
+ continue
+ }
+ fieldVal := dataRef.FieldByIndex(field.Index)
+ if fieldVal.IsZero() {
+ delete(merged, parts[0])
+ } else {
+ merged[parts[0]] = fieldVal.Interface()
+ }
+ }
+ return json.Marshal(merged)
+}
diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go
index df36b205..23db1448 100644
--- a/bridgev2/database/disappear.go
+++ b/bridgev2/database/disappear.go
@@ -12,94 +12,56 @@ import (
"time"
"go.mau.fi/util/dbutil"
- "go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
-// Deprecated: use [event.DisappearingType]
-type DisappearingType = event.DisappearingType
+// DisappearingType represents the type of a disappearing message timer.
+type DisappearingType string
-// Deprecated: use constants in event package
const (
- DisappearingTypeNone = event.DisappearingTypeNone
- DisappearingTypeAfterRead = event.DisappearingTypeAfterRead
- DisappearingTypeAfterSend = event.DisappearingTypeAfterSend
+ DisappearingTypeNone DisappearingType = ""
+ DisappearingTypeAfterRead DisappearingType = "after_read"
+ DisappearingTypeAfterSend DisappearingType = "after_send"
)
// DisappearingSetting represents a disappearing message timer setting
// by combining a type with a timer and an optional start timestamp.
type DisappearingSetting struct {
- Type event.DisappearingType
+ Type DisappearingType
Timer time.Duration
DisappearAt time.Time
}
-func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting {
- if evt == nil || evt.Type == event.DisappearingTypeNone {
- return DisappearingSetting{}
- }
- return DisappearingSetting{
- Type: evt.Type,
- Timer: evt.Timer.Duration,
- }
-}
-
-func (ds DisappearingSetting) Normalize() DisappearingSetting {
- if ds.Type == event.DisappearingTypeNone {
- ds.Timer = 0
- } else if ds.Timer == 0 {
- ds.Type = event.DisappearingTypeNone
- }
- return ds
-}
-
-func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting {
- ds.DisappearAt = start.Add(ds.Timer)
- return ds
-}
-
-func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer {
- if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 {
- return &event.BeeperDisappearingTimer{}
- }
- return &event.BeeperDisappearingTimer{
- Type: ds.Type,
- Timer: jsontime.MS(ds.Timer),
- }
-}
-
type DisappearingMessageQuery struct {
BridgeID networkid.BridgeID
*dbutil.QueryHelper[*DisappearingMessage]
}
type DisappearingMessage struct {
- BridgeID networkid.BridgeID
- RoomID id.RoomID
- EventID id.EventID
- Timestamp time.Time
+ BridgeID networkid.BridgeID
+ RoomID id.RoomID
+ EventID id.EventID
DisappearingSetting
}
const (
upsertDisappearingMessageQuery = `
- INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
+ INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at)
+ VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at
`
startDisappearingMessagesQuery = `
UPDATE disappearing_message
SET disappear_at=$1 + timer
- WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4
- RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
+ WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read'
+ RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at
`
getUpcomingDisappearingMessagesQuery = `
- SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
+ SELECT bridge_id, mx_room, mxid, type, timer, disappear_at
FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2
- ORDER BY disappear_at LIMIT $3
+ ORDER BY disappear_at
`
deleteDisappearingMessageQuery = `
DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2
@@ -111,12 +73,12 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe
return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...)
}
-func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) {
- return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano())
+func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) {
+ return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID)
}
-func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) {
- return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano(), limit)
+func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) {
+ return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano())
}
func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error {
@@ -124,19 +86,17 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even
}
func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
- var timestamp int64
var disappearAt sql.NullInt64
- err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt)
+ err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt)
if err != nil {
return nil, err
}
if disappearAt.Valid {
d.DisappearAt = time.Unix(0, disappearAt.Int64)
}
- d.Timestamp = time.Unix(0, timestamp)
return d, nil
}
func (d *DisappearingMessage) sqlVariables() []any {
- return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
+ return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
}
diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go
index 16af35ca..c32929ad 100644
--- a/bridgev2/database/ghost.go
+++ b/bridgev2/database/ghost.go
@@ -7,17 +7,12 @@
package database
import (
- "bytes"
"context"
"encoding/hex"
- "encoding/json"
- "fmt"
"go.mau.fi/util/dbutil"
- "go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/id"
)
@@ -27,55 +22,6 @@ type GhostQuery struct {
*dbutil.QueryHelper[*Ghost]
}
-type ExtraProfile map[string]json.RawMessage
-
-func (ep *ExtraProfile) Set(key string, value any) error {
- if key == "displayname" || key == "avatar_url" {
- return fmt.Errorf("cannot set reserved profile key %q", key)
- }
- marshaled, err := json.Marshal(value)
- if err != nil {
- return err
- }
- if *ep == nil {
- *ep = make(ExtraProfile)
- }
- (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled)
- return nil
-}
-
-func (ep *ExtraProfile) With(key string, value any) *ExtraProfile {
- exerrors.PanicIfNotNil(ep.Set(key, value))
- return ep
-}
-
-func canonicalizeIfObject(data json.RawMessage) json.RawMessage {
- if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
- return canonicaljson.CanonicalJSONAssumeValid(data)
- }
- return data
-}
-
-func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) {
- if len(*ep) == 0 {
- return
- }
- if *dest == nil {
- *dest = make(ExtraProfile)
- }
- for key, val := range *ep {
- if key == "displayname" || key == "avatar_url" {
- continue
- }
- existing, exists := (*dest)[key]
- if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) {
- (*dest)[key] = val
- changed = true
- }
- }
- return
-}
-
type Ghost struct {
BridgeID networkid.BridgeID
ID networkid.UserID
@@ -89,14 +35,13 @@ type Ghost struct {
ContactInfoSet bool
IsBot bool
Identifiers []string
- ExtraProfile ExtraProfile
Metadata any
}
const (
getGhostBaseQuery = `
SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
+ name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
FROM ghost
`
getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2`
@@ -104,14 +49,13 @@ const (
insertGhostQuery = `
INSERT INTO ghost (
bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
+ name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
`
updateGhostQuery = `
UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6,
- name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10,
- identifiers=$11, extra_profile=$12, metadata=$13
+ name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12
WHERE bridge_id=$1 AND id=$2
`
)
@@ -142,7 +86,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) {
&g.BridgeID, &g.ID,
&g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC,
&g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot,
- dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
+ dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
)
if err != nil {
return nil, err
@@ -172,6 +116,6 @@ func (g *Ghost) sqlVariables() []any {
g.BridgeID, g.ID,
g.Name, g.AvatarID, avatarHash, g.AvatarMXC,
g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot,
- dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
+ dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
}
}
diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go
index bca26ed5..5a1af019 100644
--- a/bridgev2/database/kvstore.go
+++ b/bridgev2/database/kvstore.go
@@ -20,10 +20,8 @@ import (
type Key string
const (
- KeySplitPortalsEnabled Key = "split_portals_enabled"
- KeyBridgeInfoVersion Key = "bridge_info_version"
- KeyEncryptionStateResynced Key = "encryption_state_resynced"
- KeyRecoveryKey Key = "recovery_key"
+ KeySplitPortalsEnabled Key = "split_portals_enabled"
+ KeyBridgeInfoVersion Key = "bridge_info_version"
)
type KVQuery struct {
diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go
index 4fd599a8..9b3b1493 100644
--- a/bridgev2/database/message.go
+++ b/bridgev2/database/message.go
@@ -11,12 +11,9 @@ import (
"crypto/sha256"
"database/sql"
"encoding/base64"
- "fmt"
"strings"
- "sync"
"time"
- "github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -27,7 +24,6 @@ type MessageQuery struct {
BridgeID networkid.BridgeID
MetaType MetaTypeCreator
*dbutil.QueryHelper[*Message]
- chunkDeleteLock sync.Mutex
}
type Message struct {
@@ -68,8 +64,8 @@ const (
getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`
getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5`
getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1`
- getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1`
- getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1`
+ getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1`
+ getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1`
getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4`
getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1`
@@ -100,10 +96,6 @@ const (
deleteMessagePartByRowIDQuery = `
DELETE FROM message WHERE bridge_id=$1 AND rowid=$2
`
- deleteMessageChunkQuery = `
- DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5
- `
- getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1`
)
func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) {
@@ -188,85 +180,6 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error {
return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID)
}
-func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) {
- res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID)
- if err != nil {
- return 0, err
- }
- return res.RowsAffected()
-}
-
-func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) {
- err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID)
- return
-}
-
-const deleteChunkSize = 100_000
-
-func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error {
- if mq.GetDB().Dialect != dbutil.SQLite {
- return nil
- }
- log := zerolog.Ctx(ctx).With().
- Str("action", "delete messages in chunks").
- Stringer("portal_key", portal).
- Logger()
- if !mq.chunkDeleteLock.TryLock() {
- log.Warn().Msg("Portal deletion lock is being held, waiting...")
- mq.chunkDeleteLock.Lock()
- log.Debug().Msg("Acquired portal deletion lock after waiting")
- }
- defer mq.chunkDeleteLock.Unlock()
- total, err := mq.CountMessagesInPortal(ctx, portal)
- if err != nil {
- return fmt.Errorf("failed to count messages in portal: %w", err)
- } else if total < deleteChunkSize/3 {
- return nil
- }
- globalMaxRowID, err := mq.getMaxRowID(ctx)
- if err != nil {
- return fmt.Errorf("failed to get max row ID: %w", err)
- }
- log.Debug().
- Int("total_count", total).
- Int64("global_max_row_id", globalMaxRowID).
- Msg("Portal has lots of messages, deleting in chunks to avoid database locks")
- maxRowID := int64(deleteChunkSize)
- globalMaxRowID += deleteChunkSize * 1.2
- var dbTimeUsed time.Duration
- globalStart := time.Now()
- for total > 500 && maxRowID < globalMaxRowID {
- start := time.Now()
- count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID)
- duration := time.Since(start)
- dbTimeUsed += duration
- if err != nil {
- return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err)
- }
- total -= int(count)
- maxRowID += deleteChunkSize
- sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond))
- log.Debug().
- Int64("max_row_id", maxRowID).
- Int64("deleted_count", count).
- Int("remaining_count", total).
- Dur("duration", duration).
- Dur("sleep_time", sleepTime).
- Msg("Deleted chunk of messages")
- select {
- case <-time.After(sleepTime):
- case <-ctx.Done():
- return ctx.Err()
- }
- }
- log.Debug().
- Int("remaining_count", total).
- Dur("db_time_used", dbTimeUsed).
- Dur("total_duration", time.Since(globalStart)).
- Msg("Finished chunked delete of messages in portal")
- return nil
-}
-
func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) {
err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count)
return
diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go
index 0e6be286..17e44b09 100644
--- a/bridgev2/database/portal.go
+++ b/bridgev2/database/portal.go
@@ -16,7 +16,6 @@ import (
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -35,20 +34,9 @@ type PortalQuery struct {
*dbutil.QueryHelper[*Portal]
}
-type CapStateFlags uint32
-
-func (csf CapStateFlags) Has(flag CapStateFlags) bool {
- return csf&flag != 0
-}
-
-const (
- CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota
-)
-
type CapabilityState struct {
Source networkid.UserLoginID `json:"source"`
ID string `json:"id"`
- Flags CapStateFlags `json:"flags"`
}
type Portal struct {
@@ -56,31 +44,30 @@ type Portal struct {
networkid.PortalKey
MXID id.RoomID
- ParentKey networkid.PortalKey
- RelayLoginID networkid.UserLoginID
- OtherUserID networkid.UserID
- Name string
- Topic string
- AvatarID networkid.AvatarID
- AvatarHash [32]byte
- AvatarMXC id.ContentURIString
- NameSet bool
- TopicSet bool
- AvatarSet bool
- NameIsCustom bool
- InSpace bool
- MessageRequest bool
- RoomType RoomType
- Disappear DisappearingSetting
- CapState CapabilityState
- Metadata any
+ ParentKey networkid.PortalKey
+ RelayLoginID networkid.UserLoginID
+ OtherUserID networkid.UserID
+ Name string
+ Topic string
+ AvatarID networkid.AvatarID
+ AvatarHash [32]byte
+ AvatarMXC id.ContentURIString
+ NameSet bool
+ TopicSet bool
+ AvatarSet bool
+ NameIsCustom bool
+ InSpace bool
+ RoomType RoomType
+ Disappear DisappearingSetting
+ CapState CapabilityState
+ Metadata any
}
const (
getPortalBaseQuery = `
SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
- name_set, topic_set, avatar_set, name_is_custom, in_space, message_request,
+ name_set, topic_set, avatar_set, name_is_custom, in_space,
room_type, disappear_type, disappear_timer, cap_state,
metadata
FROM portal
@@ -89,9 +76,7 @@ const (
getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')`
getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL`
- getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC`
getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2`
- getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3`
getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1`
getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3`
@@ -102,11 +87,11 @@ const (
bridge_id, id, receiver, mxid,
parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, topic_set, name_is_custom, in_space, message_request,
+ name_set, avatar_set, topic_set, name_is_custom, in_space,
room_type, disappear_type, disappear_timer, cap_state,
metadata, relay_bridge_id
) VALUES (
- $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24,
+ $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23,
CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END
)
`
@@ -115,8 +100,8 @@ const (
SET mxid=$4, parent_id=$5, parent_receiver=$6,
relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END,
other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13,
- name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19,
- room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24
+ name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18,
+ room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23
WHERE bridge_id=$1 AND id=$2 AND receiver=$3
`
deletePortalQuery = `
@@ -126,33 +111,15 @@ const (
reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3`
migrateToSplitPortalsQuery = `
UPDATE portal
- SET receiver=new_receiver
- FROM (
- SELECT bridge_id, id, COALESCE((
- SELECT login_id
- FROM user_portal
- WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
- LIMIT 1
- ), (
- SELECT login_id
- FROM user_portal
- WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id
- LIMIT 1
- ), (
- SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
- ), '') AS new_receiver
- FROM portal
- WHERE receiver='' AND bridge_id=$1
- ) updates
- WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS (
- SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver
- )
- `
- fixParentsAfterSplitPortalMigrationQuery = `
- UPDATE portal
- SET parent_receiver=receiver
- WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''
- AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver);
+ SET receiver=COALESCE((
+ SELECT login_id
+ FROM user_portal
+ WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
+ LIMIT 1
+ ), (
+ SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
+ ), '')
+ WHERE receiver='' AND bridge_id=$1
`
)
@@ -180,10 +147,6 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID)
}
-func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) {
- return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID)
-}
-
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID)
}
@@ -192,10 +155,6 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid.
return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID)
}
-func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) {
- return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID)
-}
-
func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) {
return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver)
}
@@ -226,14 +185,6 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error)
return res.RowsAffected()
}
-func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) {
- res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID)
- if err != nil {
- return 0, err
- }
- return res.RowsAffected()
-}
-
func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString
var disappearTimer sql.NullInt64
@@ -242,7 +193,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
&p.BridgeID, &p.ID, &p.Receiver, &mxid,
&parentID, &parentReceiver, &relayLoginID, &otherUserID,
&p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC,
- &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest,
+ &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace,
&p.RoomType, &disappearType, &disappearTimer,
dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata},
)
@@ -257,7 +208,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
}
if disappearType.Valid {
p.Disappear = DisappearingSetting{
- Type: event.DisappearingType(disappearType.String),
+ Type: DisappearingType(disappearType.String),
Timer: time.Duration(disappearTimer.Int64),
}
}
@@ -289,7 +240,7 @@ func (p *Portal) sqlVariables() []any {
p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID),
dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID),
p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC,
- p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest,
+ p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace,
p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer),
dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata},
}
diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go
deleted file mode 100644
index b667399c..00000000
--- a/bridgev2/database/publicmedia.go
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package database
-
-import (
- "context"
- "database/sql"
- "time"
-
- "go.mau.fi/util/dbutil"
-
- "maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/crypto/attachment"
- "maunium.net/go/mautrix/id"
-)
-
-type PublicMediaQuery struct {
- BridgeID networkid.BridgeID
- *dbutil.QueryHelper[*PublicMedia]
-}
-
-type PublicMedia struct {
- BridgeID networkid.BridgeID
- PublicID string
- MXC id.ContentURI
- Keys *attachment.EncryptedFile
- MimeType string
- Expiry time.Time
-}
-
-const (
- upsertPublicMediaQuery = `
- INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry)
- VALUES ($1, $2, $3, $4, $5, $6)
- ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry
- `
- getPublicMediaQuery = `
- SELECT bridge_id, public_id, mxc, keys, mimetype, expiry
- FROM public_media WHERE bridge_id=$1 AND public_id=$2
- `
-)
-
-func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error {
- ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID)
- return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...)
-}
-
-func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) {
- return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID)
-}
-
-func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) {
- var expiry sql.NullInt64
- var mimetype sql.NullString
- err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry)
- if err != nil {
- return nil, err
- }
- if expiry.Valid {
- pm.Expiry = time.Unix(0, expiry.Int64)
- }
- pm.MimeType = mimetype.String
- return pm, nil
-}
-
-func (pm *PublicMedia) sqlVariables() []any {
- return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)}
-}
diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql
index 6092dc24..4eea05bb 100644
--- a/bridgev2/database/upgrades/00-latest.sql
+++ b/bridgev2/database/upgrades/00-latest.sql
@@ -1,4 +1,4 @@
--- v0 -> v27 (compatible with v9+): Latest revision
+-- v0 -> v22 (compatible with v9+): Latest revision
CREATE TABLE "user" (
bridge_id TEXT NOT NULL,
mxid TEXT NOT NULL,
@@ -48,7 +48,6 @@ CREATE TABLE portal (
topic_set BOOLEAN NOT NULL,
name_is_custom BOOLEAN NOT NULL DEFAULT false,
in_space BOOLEAN NOT NULL,
- message_request BOOLEAN NOT NULL DEFAULT false,
room_type TEXT NOT NULL,
disappear_type TEXT,
disappear_timer BIGINT,
@@ -65,7 +64,6 @@ CREATE TABLE portal (
ON DELETE SET NULL ON UPDATE CASCADE
);
CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid);
-CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
CREATE TABLE ghost (
bridge_id TEXT NOT NULL,
@@ -80,7 +78,6 @@ CREATE TABLE ghost (
contact_info_set BOOLEAN NOT NULL,
is_bot BOOLEAN NOT NULL,
identifiers jsonb NOT NULL,
- extra_profile jsonb,
metadata jsonb NOT NULL,
PRIMARY KEY (bridge_id, id)
@@ -130,7 +127,6 @@ CREATE TABLE disappearing_message (
bridge_id TEXT NOT NULL,
mx_room TEXT NOT NULL,
mxid TEXT NOT NULL,
- timestamp BIGINT NOT NULL DEFAULT 0,
type TEXT NOT NULL,
timer BIGINT NOT NULL,
disappear_at BIGINT,
@@ -141,7 +137,6 @@ CREATE TABLE disappearing_message (
REFERENCES portal (bridge_id, mxid)
ON DELETE CASCADE
);
-CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
CREATE TABLE reaction (
bridge_id TEXT NOT NULL,
@@ -220,14 +215,3 @@ CREATE TABLE kv_store (
PRIMARY KEY (bridge_id, key)
);
-
-CREATE TABLE public_media (
- bridge_id TEXT NOT NULL,
- public_id TEXT NOT NULL,
- mxc TEXT NOT NULL,
- keys jsonb,
- mimetype TEXT,
- expiry BIGINT,
-
- PRIMARY KEY (bridge_id, public_id)
-);
diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql
deleted file mode 100644
index ecd00b8d..00000000
--- a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v23 (compatible with v9+): Add event timestamp for disappearing messages
-ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0;
diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql
deleted file mode 100644
index c4290090..00000000
--- a/bridgev2/database/upgrades/24-public-media.sql
+++ /dev/null
@@ -1,11 +0,0 @@
--- v24 (compatible with v9+): Custom URLs for public media
-CREATE TABLE public_media (
- bridge_id TEXT NOT NULL,
- public_id TEXT NOT NULL,
- mxc TEXT NOT NULL,
- keys jsonb,
- mimetype TEXT,
- expiry BIGINT,
-
- PRIMARY KEY (bridge_id, public_id)
-);
diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql
deleted file mode 100644
index b9d82a7a..00000000
--- a/bridgev2/database/upgrades/25-message-requests.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v25 (compatible with v9+): Flag for message request portals
-ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false;
diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql
deleted file mode 100644
index ae5d8cad..00000000
--- a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql
+++ /dev/null
@@ -1,3 +0,0 @@
--- v26 (compatible with v9+): Add room index for disappearing message table and portal parents
-CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
-CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql
deleted file mode 100644
index e8e0549a..00000000
--- a/bridgev2/database/upgrades/27-ghost-extra-profile.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v27 (compatible with v9+): Add column for extra ghost profile metadata
-ALTER TABLE ghost ADD COLUMN extra_profile jsonb;
diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go
index 00ff01c9..9fa6569a 100644
--- a/bridgev2/database/userlogin.go
+++ b/bridgev2/database/userlogin.go
@@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin {
func (u *UserLogin) sqlVariables() []any {
var remoteProfile dbutil.JSON
- if !u.RemoteProfile.IsZero() {
+ if !u.RemoteProfile.IsEmpty() {
remoteProfile.Data = &u.RemoteProfile
}
return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}}
diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go
index e928a4c7..278b236b 100644
--- a/bridgev2/database/userportal.go
+++ b/bridgev2/database/userportal.go
@@ -67,9 +67,6 @@ const (
markLoginAsPreferredQuery = `
UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5
`
- markAllNotInSpaceQuery = `
- UPDATE user_portal SET in_space=false WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3
- `
deleteUserPortalQuery = `
DELETE FROM user_portal WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5
`
@@ -113,10 +110,6 @@ func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogi
return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver)
}
-func (upq *UserPortalQuery) MarkAllNotInSpace(ctx context.Context, portal networkid.PortalKey) error {
- return upq.Exec(ctx, markAllNotInSpaceQuery, upq.BridgeID, portal.ID, portal.Receiver)
-}
-
func (upq *UserPortalQuery) Delete(ctx context.Context, up *UserPortal) error {
return upq.Exec(ctx, deleteUserPortalQuery, up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver)
}
diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go
index b5c37e8f..1d063088 100644
--- a/bridgev2/disappear.go
+++ b/bridgev2/disappear.go
@@ -21,7 +21,7 @@ import (
type DisappearLoop struct {
br *Bridge
- nextCheck atomic.Pointer[time.Time]
+ NextCheck time.Time
stop atomic.Pointer[context.CancelFunc]
}
@@ -35,30 +35,15 @@ func (dl *DisappearLoop) Start() {
}
log.Debug().Msg("Disappearing message loop starting")
for {
- nextCheck := time.Now().Add(DisappearCheckInterval)
- dl.nextCheck.Store(&nextCheck)
- const MessageLimit = 200
- messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit)
+ dl.NextCheck = time.Now().Add(DisappearCheckInterval)
+ messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval)
if err != nil {
log.Err(err).Msg("Failed to get upcoming disappearing messages")
} else if len(messages) > 0 {
- if len(messages) >= MessageLimit {
- lastDisappearTime := messages[len(messages)-1].DisappearAt
- log.Debug().
- Int("message_count", len(messages)).
- Time("last_due", lastDisappearTime).
- Msg("Deleting disappearing messages synchronously and checking again immediately")
- // Store the expected next check time to avoid Add spawning unnecessary goroutines.
- // This can be in the past, in which case Add will put everything in the db, which is also fine.
- dl.nextCheck.Store(&lastDisappearTime)
- // If there are many messages, process them synchronously and then check again.
- dl.sleepAndDisappear(ctx, messages...)
- continue
- }
go dl.sleepAndDisappear(ctx, messages...)
}
select {
- case <-time.After(time.Until(dl.GetNextCheck())):
+ case <-time.After(time.Until(dl.NextCheck)):
case <-ctx.Done():
log.Debug().Msg("Disappearing message loop stopping")
return
@@ -66,17 +51,6 @@ func (dl *DisappearLoop) Start() {
}
}
-func (dl *DisappearLoop) GetNextCheck() time.Time {
- if dl == nil {
- return time.Time{}
- }
- nextCheck := dl.nextCheck.Load()
- if nextCheck == nil {
- return time.Time{}
- }
- return *nextCheck
-}
-
func (dl *DisappearLoop) Stop() {
if dl == nil {
return
@@ -86,14 +60,14 @@ func (dl *DisappearLoop) Stop() {
}
}
-func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) {
- startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS)
+func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) {
+ startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages")
return
}
startedMessages = slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool {
- return dm.DisappearAt.After(dl.GetNextCheck())
+ return dm.DisappearAt.After(dl.NextCheck)
})
slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int {
return a.DisappearAt.Compare(b.DisappearAt)
@@ -110,24 +84,17 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa
Stringer("event_id", dm.EventID).
Msg("Failed to save disappearing message")
}
- if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.GetNextCheck()) {
+ if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) {
go dl.sleepAndDisappear(zerolog.Ctx(ctx).WithContext(dl.br.BackgroundCtx), dm)
}
}
func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) {
for _, msg := range dms {
- timeUntilDisappear := time.Until(msg.DisappearAt)
- if timeUntilDisappear <= 0 {
- if ctx.Err() != nil {
- return
- }
- } else {
- select {
- case <-time.After(timeUntilDisappear):
- case <-ctx.Done():
- return
- }
+ select {
+ case <-time.After(time.Until(msg.DisappearAt)):
+ case <-ctx.Done():
+ return
}
resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{
Parsed: &event.RedactionEventContent{
diff --git a/bridgev2/errors.go b/bridgev2/errors.go
index f6677d2e..c023dcdf 100644
--- a/bridgev2/errors.go
+++ b/bridgev2/errors.go
@@ -38,53 +38,35 @@ var ErrNotLoggedIn = errors.New("not logged in")
// but direct media is not enabled.
var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled")
-var ErrPortalIsDeleted = errors.New("portal is deleted")
-var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event")
-
// Common message status errors
var (
- ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
- ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
- ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
- ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
- ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
- ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
- ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
- ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
- ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
- ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
- ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
-
- ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
- ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
- ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
-
- ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true)
+ ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
+ ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage()
+ ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage()
+ ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage()
+ ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage()
+ ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage()
+ ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage()
+ ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage()
+ ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage()
+ ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
+ ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage()
+ ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
+ ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
+ ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
+ ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
+ ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
+ ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
+ ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
+ ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
+ ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
+ ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
+ ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
+ ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
+ ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
)
// Common login interface errors
diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go
index 590dd1dc..f06c0363 100644
--- a/bridgev2/ghost.go
+++ b/bridgev2/ghost.go
@@ -9,15 +9,12 @@ package bridgev2
import (
"context"
"crypto/sha256"
- "encoding/json"
"fmt"
- "maps"
"net/http"
- "slices"
"github.com/rs/zerolog"
- "go.mau.fi/util/exerrors"
"go.mau.fi/util/exmime"
+ "golang.org/x/exp/slices"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -137,11 +134,10 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32
}
type UserInfo struct {
- Identifiers []string
- Name *string
- Avatar *Avatar
- IsBot *bool
- ExtraProfile database.ExtraProfile
+ Identifiers []string
+ Name *string
+ Avatar *Avatar
+ IsBot *bool
ExtraUpdates ExtraUpdater[*Ghost]
}
@@ -162,7 +158,7 @@ func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool {
}
func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
- if ghost.AvatarID == avatar.ID && (avatar.Remove || ghost.AvatarMXC != "") && ghost.AvatarSet {
+ if ghost.AvatarID == avatar.ID && ghost.AvatarSet {
return false
}
ghost.AvatarID = avatar.ID
@@ -172,7 +168,7 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
ghost.AvatarSet = false
zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar")
return true
- } else if newHash == ghost.AvatarHash && ghost.AvatarMXC != "" && ghost.AvatarSet {
+ } else if newHash == ghost.AvatarHash && ghost.AvatarSet {
return true
}
ghost.AvatarHash = newHash
@@ -189,9 +185,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
return true
}
-func (ghost *Ghost) getExtraProfileMeta() any {
+func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra {
bridgeName := ghost.Bridge.Network.GetName()
- baseExtra := &event.BeeperProfileExtra{
+ return &event.BeeperProfileExtra{
RemoteID: string(ghost.ID),
Identifiers: ghost.Identifiers,
Service: bridgeName.BeeperBridgeType,
@@ -199,35 +195,23 @@ func (ghost *Ghost) getExtraProfileMeta() any {
IsBridgeBot: false,
IsNetworkBot: ghost.IsBot,
}
- if len(ghost.ExtraProfile) == 0 {
- return baseExtra
- }
- mergedExtra := maps.Clone(ghost.ExtraProfile)
- baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra))
- exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra))
- return mergedExtra
}
-func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool {
- if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta {
- ghost.ContactInfoSet = false
- return false
- }
+func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool {
if identifiers != nil {
slices.Sort(identifiers)
}
- changed := extraProfile.CopyTo(&ghost.ExtraProfile)
+ if ghost.ContactInfoSet &&
+ (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) &&
+ (isBot == nil || *isBot == ghost.IsBot) {
+ return false
+ }
if identifiers != nil {
- changed = changed || !slices.Equal(identifiers, ghost.Identifiers)
ghost.Identifiers = identifiers
}
if isBot != nil {
- changed = changed || *isBot != ghost.IsBot
ghost.IsBot = *isBot
}
- if ghost.ContactInfoSet && !changed {
- return false
- }
err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta())
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata")
@@ -250,7 +234,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool {
}
func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) {
- if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
+ if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
return
}
info, err := source.Client.GetUserInfo(ctx, ghost)
@@ -260,16 +244,12 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin
zerolog.Ctx(ctx).Debug().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
- Bool("has_avatar", ghost.AvatarMXC != "").
- Bool("avatar_set", ghost.AvatarSet).
Msg("Updating ghost info in IfNecessary call")
ghost.UpdateInfo(ctx, info)
} else {
zerolog.Ctx(ctx).Trace().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
- Bool("has_avatar", ghost.AvatarMXC != "").
- Bool("avatar_set", ghost.AvatarSet).
Msg("No ghost info received in IfNecessary call")
}
}
@@ -297,14 +277,9 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) {
}
if info.Avatar != nil {
update = ghost.UpdateAvatar(ctx, info.Avatar) || update
- } else if oldAvatar == "" && !ghost.AvatarSet {
- // Special case: nil avatar means we're not expecting one ever, if we don't currently have
- // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary.
- ghost.AvatarSet = true
- update = true
}
- if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil {
- update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update
+ if info.Identifiers != nil || info.IsBot != nil {
+ update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update
}
if info.ExtraUpdates != nil {
update = info.ExtraUpdates(ctx, ghost) || update
diff --git a/bridgev2/login.go b/bridgev2/login.go
index b8321719..1fa3afbc 100644
--- a/bridgev2/login.go
+++ b/bridgev2/login.go
@@ -13,7 +13,6 @@ import (
"strings"
"maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/event"
)
// LoginProcess represents a single occurrence of a user logging into the remote network.
@@ -179,8 +178,6 @@ const (
LoginInputFieldTypeToken LoginInputFieldType = "token"
LoginInputFieldTypeURL LoginInputFieldType = "url"
LoginInputFieldTypeDomain LoginInputFieldType = "domain"
- LoginInputFieldTypeSelect LoginInputFieldType = "select"
- LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code"
)
type LoginInputDataField struct {
@@ -192,13 +189,8 @@ type LoginInputDataField struct {
Name string `json:"name"`
// The description of the field shown to the user.
Description string `json:"description"`
- // A default value that the client can pre-fill the field with.
- DefaultValue string `json:"default_value,omitempty"`
// A regex pattern that the client can use to validate input client-side.
Pattern string `json:"pattern,omitempty"`
- // For fields of type select, the valid options.
- // Pattern may also be filled with a regex that matches the same options.
- Options []string `json:"options,omitempty"`
// A function that validates the input and optionally cleans it up before it's submitted to the connector.
Validate func(string) (string, error) `json:"-"`
}
@@ -273,23 +265,6 @@ func (f *LoginInputDataField) FillDefaultValidate() {
type LoginUserInputParams struct {
// The fields that the user needs to fill in.
Fields []LoginInputDataField `json:"fields"`
-
- // Attachments to display alongside the input fields.
- Attachments []*LoginUserInputAttachment `json:"attachments"`
-}
-
-type LoginUserInputAttachment struct {
- Type event.MessageType `json:"type,omitempty"`
- FileName string `json:"filename,omitempty"`
- Content []byte `json:"content,omitempty"`
- Info LoginUserInputAttachmentInfo `json:"info,omitempty"`
-}
-
-type LoginUserInputAttachmentInfo struct {
- MimeType string `json:"mimetype,omitempty"`
- Width int `json:"w,omitempty"`
- Height int `json:"h,omitempty"`
- Size int `json:"size,omitempty"`
}
type LoginCompleteParams struct {
diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go
index 5a2df953..7af2d128 100644
--- a/bridgev2/matrix/connector.go
+++ b/bridgev2/matrix/connector.go
@@ -12,21 +12,20 @@ import (
"encoding/base64"
"errors"
"fmt"
- "net/http"
"net/url"
"os"
"regexp"
"strings"
"sync"
"time"
+ "unsafe"
+ "github.com/gorilla/mux"
_ "github.com/lib/pq"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
_ "go.mau.fi/util/dbutil/litestream"
- "go.mau.fi/util/exbytes"
"go.mau.fi/util/exsync"
- "go.mau.fi/util/ptr"
"go.mau.fi/util/random"
"golang.org/x/sync/semaphore"
@@ -81,8 +80,6 @@ type Connector struct {
MediaConfig mautrix.RespMediaConfig
SpecVersions *mautrix.RespVersions
- SpecCaps *mautrix.RespCapabilities
- specCapsLock sync.Mutex
Capabilities *bridgev2.MatrixCapabilities
IgnoreUnsupportedServer bool
@@ -104,7 +101,6 @@ type Connector struct {
var (
_ bridgev2.MatrixConnector = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithServer = (*Connector)(nil)
- _ bridgev2.MatrixConnectorWithArbitraryRoomState = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil)
_ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil)
@@ -144,20 +140,13 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) {
br.EventProcessor.On(event.EventReaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent)
- br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent)
br.EventProcessor.On(event.StateMember, br.handleRoomEvent)
br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent)
- br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent)
br.EventProcessor.On(event.StateTopic, br.handleRoomEvent)
- br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent)
- br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent)
- br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent)
- br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent)
br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent)
br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent)
- br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent)
br.Bot = br.AS.BotIntent()
br.Crypto = NewCryptoHelper(br)
br.Bridge.Commands.(*commands.Processor).AddHandlers(
@@ -179,17 +168,6 @@ func (br *Connector) Start(ctx context.Context) error {
if err != nil {
return err
}
- needsStateResync := br.Config.Encryption.Default &&
- br.Bridge.DB.KV.Get(ctx, database.KeyEncryptionStateResynced) != "true"
- if needsStateResync {
- dbExists, err := br.StateStore.TableExists(ctx, "mx_version")
- if err != nil {
- return fmt.Errorf("failed to check if mx_version table exists: %w", err)
- } else if !dbExists {
- needsStateResync = false
- br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true")
- }
- }
err = br.StateStore.Upgrade(ctx)
if err != nil {
return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err}
@@ -233,59 +211,17 @@ func (br *Connector) Start(ctx context.Context) error {
br.wsStopPinger = make(chan struct{}, 1)
go br.websocketServerPinger()
}
- if needsStateResync {
- br.ResyncEncryptionState(ctx)
- }
return nil
}
-func (br *Connector) ResyncEncryptionState(ctx context.Context) {
- log := zerolog.Ctx(ctx)
- roomIDScanner := dbutil.ConvertRowFn[id.RoomID](dbutil.ScanSingleColumn[id.RoomID])
- rooms, err := roomIDScanner.NewRowIter(br.Bridge.DB.Query(ctx, `
- SELECT rooms.room_id
- FROM (SELECT DISTINCT(room_id) FROM mx_user_profile WHERE room_id<>'') rooms
- LEFT JOIN mx_room_state ON rooms.room_id = mx_room_state.room_id
- WHERE mx_room_state.encryption IS NULL
- `)).AsList()
- if err != nil {
- log.Err(err).Msg("Failed to get room list to resync state")
- return
- }
- var failedCount, successCount, forbiddenCount int
- for _, roomID := range rooms {
- if roomID == "" {
- continue
- }
- var outContent *event.EncryptionEventContent
- err = br.Bot.Client.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent)
- if errors.Is(err, mautrix.MForbidden) {
- // Most likely non-existent room
- log.Debug().Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room")
- forbiddenCount++
- } else if err != nil {
- log.Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room")
- failedCount++
- } else {
- successCount++
- }
- }
- br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true")
- log.Info().
- Int("success_count", successCount).
- Int("forbidden_count", forbiddenCount).
- Int("failed_count", failedCount).
- Msg("Resynced rooms")
-}
-
func (br *Connector) GetPublicAddress() string {
if br.Config.AppService.PublicAddress == "https://bridge.example.com" {
return ""
}
- return strings.TrimRight(br.Config.AppService.PublicAddress, "/")
+ return br.Config.AppService.PublicAddress
}
-func (br *Connector) GetRouter() *http.ServeMux {
+func (br *Connector) GetRouter() *mux.Router {
if br.GetPublicAddress() != "" {
return br.AS.Router
}
@@ -344,18 +280,16 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) {
}
func (br *Connector) ensureConnection(ctx context.Context) {
- triedToRegister := false
for {
versions, err := br.Bot.Versions(ctx)
if err != nil {
- if errors.Is(err, mautrix.MForbidden) && !triedToRegister {
+ if errors.Is(err, mautrix.MForbidden) {
br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying")
err = br.Bot.EnsureRegistered(ctx)
if err != nil {
br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN")
os.Exit(16)
}
- triedToRegister = true
} else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) {
br.logInitialRequestError(err, "/versions request failed with auth error")
os.Exit(16)
@@ -368,9 +302,6 @@ func (br *Connector) ensureConnection(ctx context.Context) {
*br.AS.SpecVersions = *versions
br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites)
br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending)
- br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange)
- br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) ||
- (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo)
break
}
}
@@ -415,21 +346,6 @@ func (br *Connector) ensureConnection(ctx context.Context) {
br.Bot.EnsureAppserviceConnection(ctx)
}
-func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities {
- br.specCapsLock.Lock()
- defer br.specCapsLock.Unlock()
- if br.SpecCaps != nil {
- return br.SpecCaps
- }
- caps, err := br.Bot.Capabilities(ctx)
- if err != nil {
- br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver")
- return nil
- }
- br.SpecCaps = caps
- return caps
-}
-
func (br *Connector) fetchMediaConfig(ctx context.Context) {
cfg, err := br.Bot.GetMediaConfig(ctx)
if err != nil {
@@ -495,15 +411,11 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI {
func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error {
if br.Websocket {
br.hasSentAnyStates = true
- return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
+ return br.AS.SendWebsocket(&appservice.WebsocketRequest{
Command: "bridge_status",
Data: state,
})
} else if br.Config.Homeserver.StatusEndpoint != "" {
- // Connecting states aren't really relevant unless the bridge runs somewhere with an unreliable network
- if state.StateEvent == status.StateConnecting {
- return nil
- }
return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken)
} else {
return nil
@@ -521,7 +433,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
log := zerolog.Ctx(ctx)
if !evt.IsSourceEventDoublePuppeted {
- err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{ms.ToCheckpoint(evt)})
+ err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)})
if err != nil {
log.Err(err).Msg("Failed to send message checkpoint")
}
@@ -538,8 +450,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
Msg("Failed to send MSS event")
}
}
- if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice &&
- (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
+ if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
content := ms.ToNoticeEvent(evt)
if editEvent != "" {
content.SetEdit(editEvent)
@@ -567,11 +478,11 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
return ""
}
-func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*status.MessageCheckpoint) error {
+func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error {
checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints}
if br.Websocket {
- return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{
+ return br.AS.SendWebsocket(&appservice.WebsocketRequest{
Command: "message_checkpoint",
Data: checkpointsJSON,
})
@@ -582,7 +493,7 @@ func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*
return nil
}
- return checkpointsJSON.SendHTTP(ctx, br.AS.HTTPClient, endpoint, br.AS.Registration.AppToken)
+ return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken)
}
func (br *Connector) ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) {
@@ -622,31 +533,6 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve
return br.Bot.PowerLevels(ctx, roomID)
}
-func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) {
- if stateKey == "" {
- switch eventType {
- case event.StateCreate:
- createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID)
- if err != nil || createEvt != nil {
- return createEvt, err
- }
- case event.StateJoinRules:
- joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID)
- if err != nil {
- return nil, err
- } else if joinRulesContent != nil {
- return &event.Event{
- Type: event.StateJoinRules,
- RoomID: roomID,
- StateKey: ptr.Ptr(""),
- Content: event.Content{Parsed: joinRulesContent},
- }, nil
- }
- }
- }
- return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey)
-}
-
func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID)
if err != nil {
@@ -687,7 +573,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr
if intent != nil {
intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp)
}
- if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction {
+ if evt.Type != event.EventEncrypted {
err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content)
if err != nil {
return nil, err
@@ -719,7 +605,7 @@ func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid.
eventID[1+hashB64Len] = ':'
copy(eventID[1+hashB64Len+1:], br.deterministicEventIDServer)
- return id.EventID(exbytes.UnsafeString(eventID))
+ return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID)))
}
func (br *Connector) GenerateDeterministicRoomID(key networkid.PortalKey) id.RoomID {
diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go
index 7f18f1f5..47226625 100644
--- a/bridgev2/matrix/crypto.go
+++ b/bridgev2/matrix/crypto.go
@@ -24,7 +24,6 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
- "maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event"
@@ -38,9 +37,9 @@ func init() {
var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil)
-var NoSessionFound = crypto.ErrNoSessionFound
-var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex
-var UnknownMessageIndex = olm.ErrUnknownMessageIndex
+var NoSessionFound = crypto.NoSessionFound
+var DuplicateMessageIndex = crypto.DuplicateMessageIndex
+var UnknownMessageIndex = olm.UnknownMessageIndex
type CryptoHelper struct {
bridge *Connector
@@ -136,19 +135,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
return err
}
if isExistingDevice {
- if !helper.verifyKeysAreOnServer(ctx) {
- return nil
- }
- } else {
- err = helper.ShareKeys(ctx)
- if err != nil {
- return fmt.Errorf("failed to share device keys: %w", err)
- }
- }
- if helper.bridge.Config.Encryption.SelfSign {
- if !helper.doSelfSign(ctx) {
- os.Exit(34)
- }
+ helper.verifyKeysAreOnServer(ctx)
}
go helper.resyncEncryptionInfo(context.TODO())
@@ -156,46 +143,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
return nil
}
-func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool {
- log := zerolog.Ctx(ctx)
- hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx)
- if err != nil {
- log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status")
- return false
- }
- log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status")
- keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey)
- if !hasKeys || keyInDB == "overwrite" {
- if keyInDB != "" && keyInDB != "overwrite" {
- log.WithLevel(zerolog.FatalLevel).
- Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.")
- return false
- }
- recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx)
- if recoveryKey != "" {
- helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey)
- }
- if err != nil {
- log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign")
- return false
- }
- log.Info().Msg("Generated new recovery key and self-signed bot device")
- } else if !isVerified {
- if keyInDB == "" {
- log.WithLevel(zerolog.FatalLevel).
- Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.")
- return false
- }
- err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB)
- if err != nil {
- log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key")
- return false
- }
- log.Info().Msg("Verified bot device with existing recovery key")
- }
- return true
-}
-
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
@@ -210,12 +157,12 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
var evt event.EncryptionEventContent
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
if err != nil {
- log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event")
+ log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
_, err = helper.store.DB.Exec(ctx, `
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
`, roomID)
if err != nil {
- log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync")
+ log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync")
}
} else {
maxAge := evt.RotationPeriodMillis
@@ -238,9 +185,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
`, maxAge, maxMessages, roomID)
if err != nil {
- log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table")
+ log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table")
} else {
- log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table")
+ log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table")
}
}
}
@@ -286,7 +233,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
if err != nil {
return nil, false, fmt.Errorf("failed to find existing device ID: %w", err)
} else if len(deviceID) > 0 {
- helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database")
+ helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
}
// Create a new client instance with the default AS settings (including as_token),
// the Login call will then override the access token in the client.
@@ -327,7 +274,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
return client, deviceID != "", nil
}
-func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool {
+func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) {
helper.log.Debug().Msg("Making sure keys are still on server")
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
@@ -340,11 +287,10 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool {
}
device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID]
if ok && len(device.Keys) > 0 {
- return true
+ return
}
helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto")
helper.Reset(ctx, false)
- return false
}
func (helper *CryptoHelper) Start() {
@@ -439,7 +385,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy
var encrypted *event.EncryptedEventContent
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
- if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) {
+ if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
return
}
helper.log.Debug().Err(err).
diff --git a/bridgev2/matrix/cryptostore.go b/bridgev2/matrix/cryptostore.go
index 4c3b5d30..234797a6 100644
--- a/bridgev2/matrix/cryptostore.go
+++ b/bridgev2/matrix/cryptostore.go
@@ -45,7 +45,7 @@ func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context,
WHERE room_id=$1
AND (membership='join' OR membership='invite')
AND user_id<>$2
- AND user_id NOT LIKE $3 ESCAPE '\'
+ AND user_id NOT LIKE $3
`, roomID, store.UserID, store.GhostIDFormat)
if err != nil {
return
diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go
index 0667981a..71c01078 100644
--- a/bridgev2/matrix/directmedia.go
+++ b/bridgev2/matrix/directmedia.go
@@ -39,7 +39,7 @@ func (br *Connector) initDirectMedia() error {
if err != nil {
return fmt.Errorf("failed to initialize media proxy: %w", err)
}
- br.MediaProxy.RegisterRoutes(br.AS.Router, br.Log.With().Str("component", "media proxy").Logger())
+ br.MediaProxy.RegisterRoutes(br.AS.Router)
br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed())
dmn.SetUseDirectMedia()
br.Log.Debug().Str("server_name", br.MediaProxy.GetServerName()).Msg("Enabled direct media access")
diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go
index f7254bd4..2088d5b1 100644
--- a/bridgev2/matrix/intent.go
+++ b/bridgev2/matrix/intent.go
@@ -9,7 +9,6 @@ package matrix
import (
"bytes"
"context"
- "encoding/json"
"errors"
"fmt"
"io"
@@ -28,7 +27,6 @@ import (
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
"maunium.net/go/mautrix/crypto/attachment"
- "maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
@@ -45,13 +43,13 @@ type ASIntent struct {
var _ bridgev2.MatrixAPI = (*ASIntent)(nil)
var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil)
-var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil)
func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) {
if extra == nil {
extra = &bridgev2.MatrixSendExtra{}
}
- if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) {
+ // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions
+ if eventType == event.EventRedaction {
parsedContent := content.Parsed.(*event.RedactionEventContent)
as.Matrix.AddDoublePuppetValue(content)
return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{
@@ -59,7 +57,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
Extra: content.Raw,
})
}
- if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction {
+ if eventType != event.EventReaction && eventType != event.EventRedaction {
msgContent, ok := content.Parsed.(*event.MessageEventContent)
if ok {
msgContent.AddPerMessageProfileFallback()
@@ -84,27 +82,16 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
eventType = event.EventEncrypted
}
}
- return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()})
-}
-
-func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) {
- if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
- return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
+ if extra.Timestamp.IsZero() {
+ return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content)
+ } else {
+ return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli())
}
- if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil {
- return nil, fmt.Errorf("failed to check if room is encrypted: %w", err)
- } else if encrypted && as.Connector.Crypto != nil {
- if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil {
- return nil, err
- }
- eventType = event.EventEncrypted
- }
- return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID})
}
func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) {
- targetContent, ok := content.Parsed.(*event.MemberEventContent)
- if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" {
+ targetContent := content.Parsed.(*event.MemberEventContent)
+ if targetContent.Displayname != "" || targetContent.AvatarURL != "" {
return
}
memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID)
@@ -139,7 +126,11 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e
if eventType == event.StateMember {
as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content)
}
- resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()})
+ if ts.IsZero() {
+ resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content)
+ } else {
+ resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli())
+ }
if err != nil && eventType == event.StateMember {
var httpErr mautrix.HTTPError
if errors.As(err, &httpErr) && httpErr.RespError != nil &&
@@ -421,7 +412,6 @@ func (as *ASIntent) UploadMediaStream(
removeAndClose(replFile)
removeAndClose(tempFile)
}
- req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
startedAsyncUpload = true
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
@@ -454,7 +444,6 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn
as.Connector.uploadSema.Release(int64(len(req.ContentBytes)))
}
}
- req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
if resp != nil {
@@ -486,62 +475,11 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr
return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL)
}
-func dataToFields(data any) (map[string]json.RawMessage, error) {
- fields, ok := data.(map[string]json.RawMessage)
- if ok {
- return fields, nil
- }
- d, err := json.Marshal(data)
- if err != nil {
- return nil, err
- }
- d = canonicaljson.CanonicalJSONAssumeValid(d)
- err = json.Unmarshal(d, &fields)
- return fields, err
-}
-
-func marshalField(val any) json.RawMessage {
- data, _ := json.Marshal(val)
- if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
- return canonicaljson.CanonicalJSONAssumeValid(data)
- }
- return data
-}
-
-var nullJSON = json.RawMessage("null")
-
func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error {
- if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
- return as.Matrix.BeeperUpdateProfile(ctx, data)
- } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo {
- fields, err := dataToFields(data)
- if err != nil {
- return fmt.Errorf("failed to marshal fields: %w", err)
- }
- currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID)
- if err != nil {
- return fmt.Errorf("failed to get current profile: %w", err)
- }
- for key, val := range fields {
- existing, ok := currentProfile.Extra[key]
- if !ok {
- if bytes.Equal(val, nullJSON) {
- continue
- }
- err = as.Matrix.SetProfileField(ctx, key, val)
- } else if !bytes.Equal(marshalField(existing), val) {
- if bytes.Equal(val, nullJSON) {
- err = as.Matrix.DeleteProfileField(ctx, key)
- } else {
- err = as.Matrix.SetProfileField(ctx, key, val)
- }
- }
- if err != nil {
- return fmt.Errorf("failed to set profile field %q: %w", key, err)
- }
- }
+ if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
+ return nil
}
- return nil
+ return as.Matrix.BeeperUpdateProfile(ctx, data)
}
func (as *ASIntent) GetMXID() id.UserID {
@@ -552,12 +490,8 @@ func (as *ASIntent) IsDoublePuppet() bool {
return as.Matrix.IsDoublePuppet()
}
-func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...bridgev2.EnsureJoinedParams) error {
- var params bridgev2.EnsureJoinedParams
- if len(extra) > 0 {
- params = extra[0]
- }
- err := as.Matrix.EnsureJoined(ctx, roomID, appservice.EnsureJoinedParams{Via: params.Via})
+func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error {
+ err := as.Matrix.EnsureJoined(ctx, roomID)
if err != nil {
return err
}
@@ -583,39 +517,6 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent {
return content
}
-func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) {
- if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
- // Hungryserv doesn't override the capabilities endpoint nor do room versions
- return
- }
- caps := as.Connector.fetchCapabilities(ctx)
- roomVer := req.RoomVersion
- if roomVer == "" && caps != nil && caps.RoomVersions != nil {
- roomVer = id.RoomVersion(caps.RoomVersions.Default)
- }
- if roomVer != "" && !roomVer.PrivilegedRoomCreators() {
- return
- }
- creators, _ := req.CreationContent["additional_creators"].([]id.UserID)
- creators = append(slices.Clone(creators), as.GetMXID())
- if req.PowerLevelOverride != nil {
- for _, creator := range creators {
- delete(req.PowerLevelOverride.Users, creator)
- }
- }
- for _, evt := range req.InitialState {
- if evt.Type != event.StatePowerLevels {
- continue
- }
- content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent)
- if ok {
- for _, creator := range creators {
- delete(content.Users, creator)
- }
- }
- }
-}
-
func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) {
if as.Connector.Config.Encryption.Default {
req.InitialState = append(req.InitialState, &event.Event{
@@ -631,7 +532,6 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom)
}
req.CreationContent["m.federate"] = false
}
- as.filterCreateRequestForV12(ctx, req)
resp, err := as.Matrix.CreateRoom(ctx, req)
if err != nil {
return "", err
@@ -673,19 +573,8 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id.
}
func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error {
- if roomID == "" {
- return nil
- }
if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) {
- err := as.Matrix.BeeperDeleteRoom(ctx, roomID)
- if err != nil {
- return err
- }
- err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal")
- }
- return nil
+ return as.Matrix.BeeperDeleteRoom(ctx, roomID)
}
members, err := as.Matrix.JoinedMembers(ctx, roomID)
if err != nil {
@@ -773,23 +662,3 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T
})
}
}
-
-func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) {
- evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID)
- if err != nil {
- return nil, err
- }
- err = evt.Content.ParseRaw(evt.Type)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content")
- }
-
- if evt.Type == event.EventEncrypted {
- if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt {
- return nil, errors.New("can't decrypt the event")
- }
- return as.Connector.Crypto.Decrypt(ctx, evt)
- }
-
- return evt, nil
-}
diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go
index 954d0ad9..84e85d24 100644
--- a/bridgev2/matrix/matrix.go
+++ b/bridgev2/matrix/matrix.go
@@ -27,11 +27,6 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) {
if br.shouldIgnoreEvent(evt) {
return
}
- if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember {
- zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events")
- br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
- return
- }
if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require {
zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required")
br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true)
@@ -68,10 +63,6 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event)
case event.EphemeralEventTyping:
typingContent := evt.Content.AsTyping()
typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser)
- case event.BeeperEphemeralEventAIStream:
- if br.shouldIgnoreEvent(evt) {
- return
- }
}
br.Bridge.QueueMatrixEvent(ctx, evt)
}
@@ -85,11 +76,6 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
- if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents {
- log.Debug().Msg("Dropping event from user with no permission to send events")
- br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
- return
- }
ctx = log.WithContext(ctx)
if br.Crypto == nil {
br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true)
@@ -101,18 +87,17 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
decryptionStart := time.Now()
decrypted, err := br.Crypto.Decrypt(ctx, evt)
decryptionRetryCount := 0
- var errorEventID id.EventID
if errors.Is(err, NoSessionFound) {
decryptionRetryCount = 1
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
- go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false)
+ go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false)
if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = br.Crypto.Decrypt(ctx, evt)
} else {
- go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID)
+ go br.waitLongerForSession(ctx, evt, decryptionStart)
return
}
}
@@ -121,18 +106,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true)
return
}
- br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart))
+ br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart))
}
-func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) {
+func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) {
log := zerolog.Ctx(ctx)
content := evt.Content.AsEncrypted()
log.Debug().
Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, requesting keys and waiting longer...")
- //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
+ var errorEventID *id.EventID
go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false)
if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@@ -157,7 +142,7 @@ type CommandProcessor interface {
}
func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event, step status.MessageCheckpointStep, retryNum int) {
- err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{{
+ err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{{
RoomID: evt.RoomID,
EventID: evt.ID,
EventType: evt.Type,
@@ -184,7 +169,7 @@ func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool {
}
func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool {
- if br.shouldIgnoreEventFromUser(evt.Sender) && evt.Type != event.StateTombstone {
+ if br.shouldIgnoreEventFromUser(evt.Sender) {
return true
}
dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey]
@@ -235,6 +220,7 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event
go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount)
decrypted.Mautrix.CheckpointSent = true
decrypted.Mautrix.DecryptionDuration = duration
+ decrypted.Mautrix.EventSource |= event.SourceDecrypted
br.EventProcessor.Dispatch(ctx, decrypted)
if errorEventID != nil && *errorEventID != "" {
_, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID)
diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go
index f5e438de..0f6aa68c 100644
--- a/bridgev2/matrix/mxmain/dberror.go
+++ b/bridgev2/matrix/mxmain/dberror.go
@@ -66,12 +66,7 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s
} else if errors.Is(err, dbutil.ErrForeignTables) {
br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info")
} else if errors.Is(err, dbutil.ErrNotOwned) {
- var noe dbutil.NotOwnedError
- if errors.As(err, &noe) && noe.Owner == br.Name {
- br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?")
- } else {
- br.Log.Info().Msg("Sharing the same database with different programs is not supported")
- }
+ br.Log.Info().Msg("Sharing the same database with different programs is not supported")
} else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) {
br.Log.Info().Msg("Downgrading the bridge is not supported")
}
diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go
deleted file mode 100644
index 1b4f1467..00000000
--- a/bridgev2/matrix/mxmain/envconfig.go
+++ /dev/null
@@ -1,161 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package mxmain
-
-import (
- "fmt"
- "iter"
- "os"
- "reflect"
- "strconv"
- "strings"
-
- "go.mau.fi/util/random"
-)
-
-var randomParseFilePrefix = random.String(16) + "READFILE:"
-
-func parseEnv(prefix string) iter.Seq2[[]string, string] {
- return func(yield func([]string, string) bool) {
- for _, s := range os.Environ() {
- if !strings.HasPrefix(s, prefix) {
- continue
- }
- kv := strings.SplitN(s, "=", 2)
- key := strings.TrimPrefix(kv[0], prefix)
- value := kv[1]
- if strings.HasSuffix(key, "_FILE") {
- key = strings.TrimSuffix(key, "_FILE")
- value = randomParseFilePrefix + value
- }
- key = strings.ToLower(key)
- if !strings.ContainsRune(key, '.') {
- key = strings.ReplaceAll(key, "__", ".")
- }
- if !yield(strings.Split(key, "."), value) {
- return
- }
- }
- }
-}
-
-func reflectYAMLFieldName(f *reflect.StructField) string {
- parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2)
- fieldName := parts[0]
- if fieldName == "-" && len(parts) == 1 {
- return ""
- }
- if fieldName == "" {
- return strings.ToLower(f.Name)
- }
- return fieldName
-}
-
-type reflectGetResult struct {
- val reflect.Value
- valKind reflect.Kind
- remainingPath []string
-}
-
-func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) {
- if len(path) == 0 {
- return &reflectGetResult{val: rv, valKind: rv.Kind()}, true
- }
- if rv.Kind() == reflect.Ptr {
- rv = rv.Elem()
- }
- switch rv.Kind() {
- case reflect.Map:
- return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true
- case reflect.Struct:
- fields := reflect.VisibleFields(rv.Type())
- for _, field := range fields {
- fieldName := reflectYAMLFieldName(&field)
- if fieldName != "" && fieldName == path[0] {
- return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:])
- }
- }
- default:
- }
- return nil, false
-}
-
-func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) {
- if len(path) > 0 && path[0] == "network" {
- return reflectGetYAML(network, path[1:])
- }
- return reflectGetYAML(main, path)
-}
-
-func formatKeyString(key []string) string {
- return strings.Join(key, "->")
-}
-
-func UpdateConfigFromEnv(cfg, networkData any, prefix string) error {
- cfgVal := reflect.ValueOf(cfg)
- networkVal := reflect.ValueOf(networkData)
- for key, value := range parseEnv(prefix) {
- field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key)
- if !ok {
- return fmt.Errorf("%s not found", formatKeyString(key))
- }
- if strings.HasPrefix(value, randomParseFilePrefix) {
- filepath := strings.TrimPrefix(value, randomParseFilePrefix)
- fileData, err := os.ReadFile(filepath)
- if err != nil {
- return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err)
- }
- value = strings.TrimSpace(string(fileData))
- }
- var parsedVal any
- var err error
- switch field.valKind {
- case reflect.String:
- parsedVal = value
- case reflect.Bool:
- parsedVal, err = strconv.ParseBool(value)
- if err != nil {
- return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
- }
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- parsedVal, err = strconv.ParseInt(value, 10, 64)
- if err != nil {
- return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
- }
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- parsedVal, err = strconv.ParseUint(value, 10, 64)
- if err != nil {
- return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
- }
- case reflect.Float32, reflect.Float64:
- parsedVal, err = strconv.ParseFloat(value, 64)
- if err != nil {
- return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
- }
- default:
- return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key))
- }
- if field.val.Kind() == reflect.Ptr {
- if field.val.IsNil() {
- field.val.Set(reflect.New(field.val.Type().Elem()))
- }
- field.val = field.val.Elem()
- }
- if field.val.Kind() == reflect.Map {
- key = key[:len(key)-len(field.remainingPath)]
- mapKeyStr := strings.Join(field.remainingPath, ".")
- key = append(key, mapKeyStr)
- if field.val.Type().Key().Kind() != reflect.String {
- return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key))
- }
- field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal))
- } else {
- field.val.Set(reflect.ValueOf(parsedVal))
- }
- }
- return nil
-}
diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml
index ccc81c4b..48e0d528 100644
--- a/bridgev2/matrix/mxmain/example-config.yaml
+++ b/bridgev2/matrix/mxmain/example-config.yaml
@@ -15,7 +15,6 @@ bridge:
# By default, users who are in the same group on the remote network will be
# in the same Matrix room bridged to that group. If this is set to true,
# every user will get their own Matrix room instead.
- # SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST.
split_portals: false
# Should the bridge resend `m.bridge` events to all portals on startup?
resend_bridge_info: false
@@ -29,9 +28,6 @@ bridge:
# How long after an unknown error should the bridge attempt a full reconnect?
# Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value.
unknown_error_auto_reconnect: null
- # Maximum number of times to do the auto-reconnect above.
- # The counter is per login, but is never reset except on logout and restart.
- unknown_error_max_auto_reconnects: 10
# Should leaving Matrix rooms be bridged as leaving groups on the remote network?
bridge_matrix_leave: false
@@ -50,11 +46,6 @@ bridge:
# Should cross-room reply metadata be bridged?
# Most Matrix clients don't support this and servers may reject such messages too.
cross_room_replies: false
- # If a state event fails to bridge, should the bridge revert any state changes made by that event?
- revert_failed_state_changes: false
- # In portals with no relay set, should Matrix users be kicked if they're
- # not logged into an account that's in the remote chat?
- kick_matrix_users: true
# What should be done to portal rooms when a user logs out or is logged out?
# Permitted values:
@@ -244,9 +235,6 @@ matrix:
# The threshold as bytes after which the bridge should roundtrip uploads via the disk
# rather than keeping the whole file in memory.
upload_file_threshold: 5242880
- # Should the bridge set additional custom profile info for ghosts?
- # This can make a lot of requests, as there's no batch profile update endpoint.
- ghost_extra_profile_info: false
# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors.
analytics:
@@ -259,8 +247,10 @@ analytics:
# Settings for provisioning API
provisioning:
+ # Prefix for the provisioning API paths.
+ prefix: /_matrix/provision
# Shared secret for authentication. If set to "generate" or null, a random secret will be generated,
- # or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters.
+ # or if set to "disable", the provisioning API will be disabled.
shared_secret: generate
# Whether to allow provisioning API requests to be authed using Matrix access tokens.
# This follows the same rules as double puppeting to determine which server to contact to check the token,
@@ -286,14 +276,6 @@ public_media:
expiry: 0
# Length of hash to use for public media URLs. Must be between 0 and 32.
hash_length: 32
- # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served.
- # If you change this, you must configure your reverse proxy to rewrite the path accordingly.
- path_prefix: /_mautrix/publicmedia
- # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs?
- # If false, the generated URLs will just have the MXC URI and a HMAC signature.
- # The hash_length field will be used to decide the length of the generated URL.
- # This also allows invalidating URLs by deleting the database entry.
- use_database: false
# Settings for converting remote media to custom mxc:// URIs instead of reuploading.
# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html
@@ -384,12 +366,6 @@ encryption:
# Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861).
# Changing this option requires updating the appservice registration file.
msc4190: false
- # Whether to encrypt reactions and reply metadata as per MSC4392.
- msc4392: false
- # Should the bridge bot generate a recovery key and cross-signing keys and verify itself?
- # Note that without the latest version of MSC4190, this will fail if you reset the bridge database.
- # The generated recovery key will be saved in the kv_store table under `recovery_key`.
- self_sign: false
# Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled.
# You must use a client that supports requesting keys from other users to use this feature.
allow_key_sharing: true
@@ -452,16 +428,6 @@ encryption:
# You should not enable this option unless you understand all the implications.
disable_device_change_key_rotation: false
-# Prefix for environment variables. All variables with this prefix must map to valid config fields.
-# Nesting in variable names is represented with a dot (.).
-# If there are no dots in the name, two underscores (__) are replaced with a dot.
-#
-# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token.
-# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily.
-#
-# If this is null, reading config fields from environment will be disabled.
-env_config_prefix: null
-
# Logging config. See https://github.com/tulir/zeroconfig for details.
logging:
min_level: debug
diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go
index 97cdeddf..c8eb820b 100644
--- a/bridgev2/matrix/mxmain/legacymigrate.go
+++ b/bridgev2/matrix/mxmain/legacymigrate.go
@@ -135,10 +135,7 @@ func (br *BridgeMain) CheckLegacyDB(
}
var dbVersion int
err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion)
- if err != nil {
- log.Fatal().Err(err).Msg("Failed to get database version")
- return
- } else if dbVersion < expectedVersion {
+ if dbVersion < expectedVersion {
log.Fatal().
Int("expected_version", expectedVersion).
Int("version", dbVersion).
diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go
index 1e8b51d1..e6219c50 100644
--- a/bridgev2/matrix/mxmain/main.go
+++ b/bridgev2/matrix/mxmain/main.go
@@ -26,7 +26,6 @@ import (
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exzerolog"
- "go.mau.fi/util/progver"
"gopkg.in/yaml.v3"
flag "maunium.net/go/mauflag"
@@ -63,9 +62,6 @@ type BridgeMain struct {
// git tag to see if the built version is the release or a dev build.
// You can either bump this right after a release or right before, as long as it matches on the release commit.
Version string
- // SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning,
- // such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch.
- SemCalVer bool
// PostInit is a function that will be called after the bridge has been initialized but before it is started.
PostInit func()
@@ -90,7 +86,11 @@ type BridgeMain struct {
RegistrationPath string
SaveConfig bool
- ver progver.ProgramVersion
+ baseVersion string
+ commit string
+ LinkifiedVersion string
+ VersionDesc string
+ BuildTime time.Time
AdditionalShortFlags string
AdditionalLongFlags string
@@ -99,7 +99,14 @@ type BridgeMain struct {
}
type VersionJSONOutput struct {
- progver.ProgramVersion
+ Name string
+ URL string
+
+ Version string
+ IsRelease bool
+ Commit string
+ FormattedVersion string
+ BuildTime time.Time
OS string
Arch string
@@ -140,11 +147,18 @@ func (br *BridgeMain) PreInit() {
flag.PrintHelp()
os.Exit(0)
} else if *version {
- fmt.Println(br.ver.VersionDescription)
+ fmt.Println(br.VersionDesc)
os.Exit(0)
} else if *versionJSON {
output := VersionJSONOutput{
- ProgramVersion: br.ver,
+ URL: br.URL,
+ Name: br.Name,
+
+ Version: br.baseVersion,
+ IsRelease: br.Version == br.baseVersion,
+ Commit: br.commit,
+ FormattedVersion: br.Version,
+ BuildTime: br.BuildTime,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
@@ -226,8 +240,8 @@ func (br *BridgeMain) Init() {
br.Log.Info().
Str("name", br.Name).
- Str("version", br.ver.FormattedVersion).
- Time("built_at", br.ver.BuildTime).
+ Str("version", br.Version).
+ Time("built_at", br.BuildTime).
Str("go_version", runtime.Version()).
Msg("Initializing bridge")
@@ -241,7 +255,7 @@ func (br *BridgeMain) Init() {
br.Matrix.AS.DoublePuppetValue = br.Name
br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{
Func: func(ce *commands.Event) {
- ce.Reply(br.ver.MarkdownDescription())
+ ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123))
},
Name: "version",
Help: commands.HelpMeta{
@@ -354,13 +368,6 @@ func (br *BridgeMain) LoadConfig() {
}
}
cfg.Bridge.Backfill = cfg.Backfill
- if cfg.EnvConfigPrefix != "" {
- err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix)
- if err != nil {
- _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err)
- os.Exit(10)
- }
- }
br.Config = &cfg
}
@@ -439,12 +446,42 @@ func (br *BridgeMain) Stop() {
//
// (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`)
func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) {
- br.ver = progver.ProgramVersion{
- Name: br.Name,
- URL: br.URL,
- BaseVersion: br.Version,
- SemCalVer: br.SemCalVer,
- }.Init(tag, commit, rawBuildTime)
- mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent)
- br.Version = br.ver.FormattedVersion
+ br.baseVersion = br.Version
+ if len(tag) > 0 && tag[0] == 'v' {
+ tag = tag[1:]
+ }
+ if tag != br.Version {
+ suffix := ""
+ if !strings.HasSuffix(br.Version, "+dev") {
+ suffix = "+dev"
+ }
+ if len(commit) > 8 {
+ br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8])
+ } else {
+ br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix)
+ }
+ }
+
+ br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version)
+ if tag == br.Version {
+ br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag)
+ } else if len(commit) > 8 {
+ br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1)
+ }
+ var buildTime time.Time
+ if rawBuildTime != "unknown" {
+ buildTime, _ = time.Parse(time.RFC3339, rawBuildTime)
+ }
+ var builtWith string
+ if buildTime.IsZero() {
+ rawBuildTime = "unknown"
+ builtWith = runtime.Version()
+ } else {
+ rawBuildTime = buildTime.Format(time.RFC1123)
+ builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version())
+ }
+ mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent)
+ br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith)
+ br.commit = commit
+ br.BuildTime = buildTime
}
diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go
index 243b91da..f865a19e 100644
--- a/bridgev2/matrix/provisioning.go
+++ b/bridgev2/matrix/provisioning.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -17,20 +17,18 @@ import (
"sync"
"time"
+ "github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
- "go.mau.fi/util/exerrors"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/exstrings"
"go.mau.fi/util/jsontime"
- "go.mau.fi/util/ptr"
"go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/bridgev2/provisionutil"
"maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/federation"
"maunium.net/go/mautrix/id"
@@ -42,7 +40,7 @@ type matrixAuthCacheEntry struct {
}
type ProvisioningAPI struct {
- Router *http.ServeMux
+ Router *mux.Router
br *Connector
log zerolog.Logger
@@ -85,18 +83,24 @@ const (
provisioningUserKey provisioningContextKey = iota
provisioningUserLoginKey
provisioningLoginProcessKey
- ProvisioningKeyRequest
)
+const ProvisioningKeyRequest = "fi.mau.provision.request"
+
func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}
-func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
+func (prov *ProvisioningAPI) GetRouter() *mux.Router {
return prov.Router
}
-func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI {
+type IProvisioningAPI interface {
+ GetRouter() *mux.Router
+ GetUser(r *http.Request) *bridgev2.User
+}
+
+func (br *Connector) GetProvisioning() IProvisioningAPI {
return br.Provisioning
}
@@ -112,57 +116,41 @@ func (prov *ProvisioningAPI) Init() {
tp.Dialer.Timeout = 10 * time.Second
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
- prov.Router = http.NewServeMux()
- prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami)
- prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities)
- prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows)
- prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
- prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep)
- prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
- prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins)
- prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList)
- prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
- prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
- prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
- prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup)
+ prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
+ prov.Router.Use(hlog.NewHandler(prov.log))
+ prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id"))
+ prov.Router.Use(exhttp.CORSMiddleware)
+ prov.Router.Use(requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}))
+ prov.Router.Use(prov.AuthMiddleware)
+ prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
+ prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
+ prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
+ prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
+ prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
+ prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
+ prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
+ prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
+ prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers)
+ prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
+ prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
+ prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)
if prov.br.Config.Provisioning.EnableSessionTransfers {
prov.log.Debug().Msg("Enabling session transfer API")
- prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer)
- prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer)
+ prov.Router.Path("/v3/session_transfer/init").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostInitSessionTransfer)
+ prov.Router.Path("/v3/session_transfer/finish").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostFinishSessionTransfer)
}
if prov.br.Config.Provisioning.DebugEndpoints {
prov.log.Debug().Msg("Enabling debug API at /debug")
- debugRouter := http.NewServeMux()
- debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline)
- debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile)
- debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol)
- debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace)
- debugRouter.HandleFunc("/pprof/", pprof.Index)
- prov.br.AS.Router.Handle("/debug/", exhttp.ApplyMiddleware(
- debugRouter,
- exhttp.StripPrefix("/debug"),
- hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()),
- requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
- prov.DebugAuthMiddleware,
- ))
+ r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
+ r.Use(prov.DebugAuthMiddleware)
+ r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet)
+ r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet)
+ r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet)
+ r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet)
+ r.PathPrefix("/pprof/").HandlerFunc(pprof.Index)
}
-
- errorBodies := exhttp.ErrorBodies{
- NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
- MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
- }
- prov.br.AS.Router.Handle("/_matrix/provision/", exhttp.ApplyMiddleware(
- prov.Router,
- exhttp.StripPrefix("/_matrix/provision"),
- hlog.NewHandler(prov.log),
- hlog.RequestIDHandler("request_id", "Request-Id"),
- exhttp.CORSMiddleware,
- requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
- exhttp.HandleErrors(errorBodies),
- prov.AuthMiddleware,
- ))
}
func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error {
@@ -206,20 +194,12 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI
}
}
-func disabledAuth(w http.ResponseWriter, r *http.Request) {
- mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w)
-}
-
func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
- secret := prov.br.Config.Provisioning.SharedSecret
- if len(secret) < 16 {
- return http.HandlerFunc(disabledAuth)
- }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth == "" {
mautrix.MMissingToken.WithMessage("Missing auth token").Write(w)
- } else if !exstrings.ConstantTimeEqual(auth, secret) {
+ } else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w)
} else {
h.ServeHTTP(w, r)
@@ -228,10 +208,6 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
}
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
- secret := prov.br.Config.Provisioning.SharedSecret
- if len(secret) < 16 {
- return http.HandlerFunc(disabledAuth)
- }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth == "" && prov.GetAuthFromRequest != nil {
@@ -245,7 +221,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
if userID == "" && prov.GetUserIDFromRequest != nil {
userID = prov.GetUserIDFromRequest(r)
}
- if !exstrings.ConstantTimeEqual(auth, secret) {
+ if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
var err error
if strings.HasPrefix(auth, "openid:") {
err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:"))
@@ -274,6 +250,38 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r)
ctx = context.WithValue(ctx, provisioningUserKey, user)
+ if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
+ prov.loginsLock.RLock()
+ login, ok := prov.logins[loginID]
+ prov.loginsLock.RUnlock()
+ if !ok {
+ zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found")
+ mautrix.MNotFound.WithMessage("Login not found").Write(w)
+ return
+ }
+ login.Lock.Lock()
+ // This will only unlock after the handler runs
+ defer login.Lock.Unlock()
+ stepID := mux.Vars(r)["stepID"]
+ if login.NextStep.StepID != stepID {
+ zerolog.Ctx(r.Context()).Warn().
+ Str("request_step_id", stepID).
+ Str("expected_step_id", login.NextStep.StepID).
+ Msg("Step ID does not match")
+ mautrix.MBadState.WithMessage("Step ID does not match").Write(w)
+ return
+ }
+ stepType := mux.Vars(r)["stepType"]
+ if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
+ zerolog.Ctx(r.Context()).Warn().
+ Str("request_step_type", stepType).
+ Str("expected_step_type", string(login.NextStep.Type)).
+ Msg("Step type does not match")
+ mautrix.MBadState.WithMessage("Step type does not match").Write(w)
+ return
+ }
+ ctx = context.WithValue(ctx, provisioningLoginProcessKey, login)
+ }
h.ServeHTTP(w, r.WithContext(ctx))
})
}
@@ -324,7 +332,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) {
prevState.UserID = ""
prevState.RemoteID = ""
prevState.RemoteName = ""
- prevState.RemoteProfile = status.RemoteProfile{}
+ prevState.RemoteProfile = nil
resp.Logins[i] = RespWhoamiLogin{
StateEvent: prevState.StateEvent,
StateTS: prevState.Timestamp,
@@ -356,24 +364,18 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques
})
}
-func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) {
- exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning)
-}
-
var ErrNilStep = errors.New("bridge returned nil step with no error")
-var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"}
func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) {
overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r)
if failed {
return
}
- user := prov.GetUser(r)
- if overrideLogin == nil && user.HasTooManyLogins() {
- ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w)
- return
- }
- login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID"))
+ login, err := prov.net.CreateLogin(
+ r.Context(),
+ prov.GetUser(r),
+ mux.Vars(r)["flowID"],
+ )
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
RespondWithError(w, err, "Internal error creating login process")
@@ -403,18 +405,10 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
Override: overrideLogin,
}
prov.loginsLock.Unlock()
- zerolog.Ctx(r.Context()).Info().
- Any("first_step", firstStep).
- Msg("Created login process")
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep})
}
func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) {
- zerolog.Ctx(ctx).Info().
- Str("step_id", step.StepID).
- Str("user_login_id", string(step.CompleteParams.UserLoginID)).
- Msg("Login completed successfully")
- prov.deleteLogin(login, false)
if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID {
return
}
@@ -428,61 +422,6 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
}, bridgev2.DeleteOpts{LogoutRemote: true})
}
-func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) {
- if cancel {
- login.Process.Cancel()
- }
- prov.loginsLock.Lock()
- delete(prov.logins, login.ID)
- prov.loginsLock.Unlock()
-}
-
-func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) {
- loginID := r.PathValue("loginProcessID")
- prov.loginsLock.RLock()
- login, ok := prov.logins[loginID]
- prov.loginsLock.RUnlock()
- if !ok {
- zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found")
- mautrix.MNotFound.WithMessage("Login not found").Write(w)
- return
- }
- login.Lock.Lock()
- // This will only unlock after the handler runs
- defer login.Lock.Unlock()
- stepID := r.PathValue("stepID")
- if login.NextStep.StepID != stepID {
- zerolog.Ctx(r.Context()).Warn().
- Str("request_step_id", stepID).
- Str("expected_step_id", login.NextStep.StepID).
- Msg("Step ID does not match")
- mautrix.MBadState.WithMessage("Step ID does not match").Write(w)
- return
- }
- stepType := r.PathValue("stepType")
- if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
- zerolog.Ctx(r.Context()).Warn().
- Str("request_step_type", stepType).
- Str("expected_step_type", string(login.NextStep.Type)).
- Msg("Step type does not match")
- mautrix.MBadState.WithMessage("Step type does not match").Write(w)
- return
- }
- ctx := context.WithValue(r.Context(), provisioningLoginProcessKey, login)
- r = r.WithContext(ctx)
- switch bridgev2.LoginStepType(r.PathValue("stepType")) {
- case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies:
- prov.PostLoginSubmitInput(w, r)
- case bridgev2.LoginStepTypeDisplayAndWait:
- prov.PostLoginWait(w, r)
- case bridgev2.LoginStepTypeComplete:
- fallthrough
- default:
- // This is probably impossible because of the above check that the next step type matches the request.
- mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w)
- }
-}
-
func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) {
var params map[string]string
err := json.NewDecoder(r.Body).Decode(¶ms)
@@ -507,14 +446,11 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input")
RespondWithError(w, err, "Internal error submitting input")
- prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
- } else {
- zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
@@ -528,21 +464,18 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait")
RespondWithError(w, err, "Internal error waiting for login")
- prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
- } else {
- zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
user := prov.GetUser(r)
- userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
+ userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
if userLoginID == "all" {
for {
login := user.GetDefaultLogin()
@@ -619,23 +552,115 @@ func RespondWithError(w http.ResponseWriter, err error, message string) {
}
}
+type RespResolveIdentifier struct {
+ ID networkid.UserID `json:"id"`
+ Name string `json:"name,omitempty"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ Identifiers []string `json:"identifiers,omitempty"`
+ MXID id.UserID `json:"mxid,omitempty"`
+ DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
+}
+
func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) {
login := prov.GetLoginForRequest(w, r)
if login == nil {
return
}
- resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat)
+ api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
+ if !ok {
+ mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w)
+ return
+ }
+ resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
if err != nil {
+ zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
RespondWithError(w, err, "Internal error resolving identifier")
+ return
} else if resp == nil {
mautrix.MNotFound.WithMessage("Identifier not found").Write(w)
- } else {
- status := http.StatusOK
- if resp.JustCreated {
- status = http.StatusCreated
- }
- exhttp.WriteJSONResponse(w, status, resp)
+ return
}
+ apiResp := &RespResolveIdentifier{
+ ID: resp.UserID,
+ }
+ status := http.StatusOK
+ if resp.Ghost != nil {
+ if resp.UserInfo != nil {
+ resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo)
+ }
+ apiResp.Name = resp.Ghost.Name
+ apiResp.AvatarURL = resp.Ghost.AvatarMXC
+ apiResp.Identifiers = resp.Ghost.Identifiers
+ apiResp.MXID = resp.Ghost.Intent.GetMXID()
+ } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
+ apiResp.Name = *resp.UserInfo.Name
+ }
+ if resp.Chat != nil {
+ if resp.Chat.Portal == nil {
+ resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey)
+ if err != nil {
+ zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal")
+ mautrix.MUnknown.WithMessage("Failed to get portal").Write(w)
+ return
+ }
+ }
+ if createChat && resp.Chat.Portal.MXID == "" {
+ status = http.StatusCreated
+ err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo)
+ if err != nil {
+ zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room")
+ mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w)
+ return
+ }
+ }
+ apiResp.DMRoomID = resp.Chat.Portal.MXID
+ }
+ exhttp.WriteJSONResponse(w, status, apiResp)
+}
+
+type RespGetContactList struct {
+ Contacts []*RespResolveIdentifier `json:"contacts"`
+}
+
+func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) {
+ apiResp = make([]*RespResolveIdentifier, len(resp))
+ for i, contact := range resp {
+ apiContact := &RespResolveIdentifier{
+ ID: contact.UserID,
+ }
+ apiResp[i] = apiContact
+ if contact.UserInfo != nil {
+ if contact.UserInfo.Name != nil {
+ apiContact.Name = *contact.UserInfo.Name
+ }
+ if contact.UserInfo.Identifiers != nil {
+ apiContact.Identifiers = contact.UserInfo.Identifiers
+ }
+ }
+ if contact.Ghost != nil {
+ if contact.Ghost.Name != "" {
+ apiContact.Name = contact.Ghost.Name
+ }
+ if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
+ apiContact.Identifiers = contact.Ghost.Identifiers
+ }
+ apiContact.AvatarURL = contact.Ghost.AvatarMXC
+ apiContact.MXID = contact.Ghost.Intent.GetMXID()
+ }
+ if contact.Chat != nil {
+ if contact.Chat.Portal == nil {
+ var err error
+ contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
+ }
+ }
+ if contact.Chat.Portal != nil {
+ apiContact.DMRoomID = contact.Chat.Portal.MXID
+ }
+ }
+ }
+ return
}
func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) {
@@ -643,18 +668,30 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque
if login == nil {
return
}
- resp, err := provisionutil.GetContactList(r.Context(), login)
- if err != nil {
- RespondWithError(w, err, "Internal error getting contact list")
+ api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
+ if !ok {
+ mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w)
return
}
- exhttp.WriteJSONResponse(w, http.StatusOK, resp)
+ resp, err := api.GetContactList(r.Context())
+ if err != nil {
+ zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
+ RespondWithError(w, err, "Internal error fetching contact list")
+ return
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{
+ Contacts: prov.processResolveIdentifiers(r.Context(), resp),
+ })
}
type ReqSearchUsers struct {
Query string `json:"query"`
}
+type RespSearchUsers struct {
+ Results []*RespResolveIdentifier `json:"results"`
+}
+
func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) {
var req ReqSearchUsers
err := json.NewDecoder(r.Body).Decode(&req)
@@ -667,12 +704,20 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ
if login == nil {
return
}
- resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query)
- if err != nil {
- RespondWithError(w, err, "Internal error searching users")
+ api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
+ if !ok {
+ mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w)
return
}
- exhttp.WriteJSONResponse(w, http.StatusOK, resp)
+ resp, err := api.SearchUsers(r.Context(), req.Query)
+ if err != nil {
+ zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
+ RespondWithError(w, err, "Internal error fetching contact list")
+ return
+ }
+ exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{
+ Results: prov.processResolveIdentifiers(r.Context(), resp),
+ })
}
func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) {
@@ -684,24 +729,11 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request
}
func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) {
- var req bridgev2.GroupCreateParams
- err := json.NewDecoder(r.Body).Decode(&req)
- if err != nil {
- zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
- mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
- return
- }
- req.Type = r.PathValue("type")
login := prov.GetLoginForRequest(w, r)
if login == nil {
return
}
- resp, err := provisionutil.CreateGroup(r.Context(), login, &req)
- if err != nil {
- RespondWithError(w, err, "Internal error creating group")
- return
- }
- exhttp.WriteJSONResponse(w, http.StatusOK, resp)
+ mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w)
}
type ReqExportCredentials struct {
diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml
index 26068db4..b9879ea5 100644
--- a/bridgev2/matrix/provisioning.yaml
+++ b/bridgev2/matrix/provisioning.yaml
@@ -361,25 +361,14 @@ paths:
$ref: '#/components/responses/InternalError'
501:
$ref: '#/components/responses/NotSupported'
- /v3/create_group/{type}:
+ /v3/create_group:
post:
tags: [ snc ]
summary: Create a group chat on the remote network.
operationId: createGroup
parameters:
- $ref: "#/components/parameters/loginID"
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/GroupCreateParams'
responses:
- 200:
- description: Identifier resolved successfully
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/CreatedGroup'
401:
$ref: '#/components/responses/Unauthorized'
404:
@@ -400,7 +389,7 @@ components:
- username
- meow@example.com
loginID:
- name: login_id
+ name: loginID
in: query
description: An optional explicit login ID to do the action through.
required: false
@@ -583,74 +572,6 @@ components:
description: The Matrix room ID of the direct chat with the user.
examples:
- '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
- GroupCreateParams:
- type: object
- description: |
- Parameters for creating a group chat.
- The /capabilities endpoint response must be checked to see which fields are actually allowed.
- properties:
- type:
- type: string
- description: The type of group to create.
- examples:
- - channel
- username:
- type: string
- description: The public username for the created group.
- participants:
- type: array
- description: The users to add to the group initially.
- items:
- type: string
- parent:
- type: object
- name:
- type: object
- description: The `m.room.name` event content for the room.
- properties:
- name:
- type: string
- avatar:
- type: object
- description: The `m.room.avatar` event content for the room.
- properties:
- url:
- type: string
- format: mxc
- topic:
- type: object
- description: The `m.room.topic` event content for the room.
- properties:
- topic:
- type: string
- disappear:
- type: object
- description: The `com.beeper.disappearing_timer` event content for the room.
- properties:
- type:
- type: string
- timer:
- type: number
- room_id:
- type: string
- format: matrix_room_id
- description: |
- An existing Matrix room ID to bridge to.
- The other parameters must be already in sync with the room state when using this parameter.
- CreatedGroup:
- type: object
- description: A successfully created group chat.
- required: [id, mxid]
- properties:
- id:
- type: string
- description: The internal chat ID of the created group.
- mxid:
- type: string
- format: matrix_room_id
- description: The Matrix room ID of the portal.
- examples:
- - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
LoginStep:
type: object
description: A step in a login process.
@@ -714,7 +635,7 @@ components:
type:
type: string
description: The type of field.
- enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ]
+ enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ]
id:
type: string
description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge.
@@ -728,53 +649,10 @@ components:
description: A more detailed description of the field shown to the user.
examples:
- Include the country code with a +
- default_value:
- type: string
- description: A default value that the client can pre-fill the field with.
pattern:
type: string
format: regex
description: A regular expression that the field value must match.
- options:
- type: array
- description: For fields of type select, the valid options.
- items:
- type: string
- attachments:
- type: array
- description: A list of media attachments to show the user alongside the form fields.
- items:
- type: object
- description: A media attachment to show the user.
- required: [ type, filename, content ]
- properties:
- type:
- type: string
- description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported.
- enum: [ m.image, m.audio ]
- filename:
- type: string
- description: The filename for the media attachment.
- content:
- type: string
- description: The raw file content for the attachment encoded in base64.
- info:
- type: object
- description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted.
- properties:
- mimetype:
- type: string
- description: The MIME type for the media content.
- examples: [ image/png, audio/mpeg ]
- w:
- type: number
- description: The width of the media in pixels. Only applicable for images and videos.
- h:
- type: number
- description: The height of the media in pixels. Only applicable for images and videos.
- size:
- type: number
- description: The size of the media content in number of bytes. Strongly recommended to include.
- description: Cookie login step
required: [ type, cookies ]
properties:
diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go
index 82ea8c2b..9db5f442 100644
--- a/bridgev2/matrix/publicmedia.go
+++ b/bridgev2/matrix/publicmedia.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,26 +7,18 @@
package matrix
import (
- "context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
- "mime"
"net/http"
- "net/url"
- "slices"
- "strings"
"time"
- "github.com/rs/zerolog"
+ "github.com/gorilla/mux"
"maunium.net/go/mautrix/bridgev2"
- "maunium.net/go/mautrix/bridgev2/database"
- "maunium.net/go/mautrix/crypto/attachment"
- "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -43,10 +35,7 @@ func (br *Connector) initPublicMedia() error {
return fmt.Errorf("public media hash length is negative")
}
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
- br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia)
- br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia)
- br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
- br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia)
+ br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
return nil
}
@@ -57,20 +46,6 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte {
return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)]
}
-func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte {
- hasher := hmac.New(sha256.New, br.pubMediaSigKey)
- hasher.Write([]byte(pm.MXC.String()))
- hasher.Write([]byte(pm.MimeType))
- if pm.Keys != nil {
- hasher.Write([]byte(pm.Keys.Version))
- hasher.Write([]byte(pm.Keys.Key.Algorithm))
- hasher.Write([]byte(pm.Keys.Key.Key))
- hasher.Write([]byte(pm.Keys.InitVector))
- hasher.Write([]byte(pm.Keys.Hashes.SHA256))
- }
- return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength]
-}
-
func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte {
var expiresAt []byte
if br.Config.PublicMedia.Expiry > 0 {
@@ -101,15 +76,16 @@ var proxyHeadersToCopy = []string{
}
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
+ vars := mux.Vars(r)
contentURI := id.ContentURI{
- Homeserver: r.PathValue("server"),
- FileID: r.PathValue("mediaID"),
+ Homeserver: vars["server"],
+ FileID: vars["mediaID"],
}
if !contentURI.IsValid() {
http.Error(w, "invalid content URI", http.StatusBadRequest)
return
}
- checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
+ checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
return
@@ -120,47 +96,9 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
http.Error(w, "checksum expired", http.StatusGone)
return
}
- br.doProxyMedia(w, r, contentURI, nil, "")
-}
-
-func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) {
- if !br.Config.PublicMedia.UseDatabase {
- http.Error(w, "public media short links are disabled", http.StatusNotFound)
- return
- }
- log := zerolog.Ctx(r.Context())
- media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID"))
- if err != nil {
- log.Err(err).Msg("Failed to get public media from database")
- http.Error(w, "failed to get media metadata", http.StatusInternalServerError)
- return
- } else if media == nil {
- http.Error(w, "media ID not found", http.StatusNotFound)
- return
- } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) {
- // This is not gone as it can still be refreshed in the DB
- http.Error(w, "media expired", http.StatusNotFound)
- return
- } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil {
- http.Error(w, "media keys are malformed", http.StatusInternalServerError)
- return
- }
- br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType)
-}
-
-var safeMimes = []string{
- "text/css", "text/plain", "text/csv",
- "application/json", "application/ld+json",
- "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif",
- "video/mp4", "video/webm", "video/ogg", "video/quicktime",
- "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave",
- "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac",
-}
-
-func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) {
resp, err := br.Bot.Download(r.Context(), contentURI)
if err != nil {
- zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
+ br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
http.Error(w, "failed to download media", http.StatusInternalServerError)
return
}
@@ -168,41 +106,11 @@ func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, conten
for _, hdr := range proxyHeadersToCopy {
w.Header()[hdr] = resp.Header[hdr]
}
- stream := resp.Body
- if encInfo != nil {
- if mimeType == "" {
- mimeType = "application/octet-stream"
- }
- contentDisposition := "attachment"
- if slices.Contains(safeMimes, mimeType) {
- contentDisposition = "inline"
- }
- dispositionArgs := map[string]string{}
- if filename := r.PathValue("filename"); filename != "" {
- dispositionArgs["filename"] = filename
- }
- w.Header().Set("Content-Type", mimeType)
- w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs))
- // Note: this won't check the Close result like it should, but it's probably not a big deal here
- stream = encInfo.DecryptStream(stream)
- } else if filename := r.PathValue("filename"); filename != "" {
- contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
- if contentDisposition == "" {
- contentDisposition = "attachment"
- }
- w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{
- "filename": filename,
- }))
- }
w.WriteHeader(http.StatusOK)
- _, _ = io.Copy(w, stream)
+ _, _ = io.Copy(w, resp.Body)
}
func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string {
- return br.getPublicMediaAddressWithFileName(contentURI, "")
-}
-
-func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string {
if br.pubMediaSigKey == nil {
return ""
}
@@ -210,69 +118,11 @@ func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIS
if err != nil || !parsed.IsValid() {
return ""
}
- fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_"))
- if fileName == ".." {
- fileName = ""
- }
- parts := []string{
+ return fmt.Sprintf(
+ "%s/_mautrix/publicmedia/%s/%s/%s",
br.GetPublicAddress(),
- strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
parsed.Homeserver,
parsed.FileID,
base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)),
- fileName,
- }
- if fileName == "" {
- parts = parts[:len(parts)-1]
- }
- return strings.Join(parts, "/")
-}
-
-func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) {
- if br.pubMediaSigKey == nil {
- return "", bridgev2.ErrPublicMediaDisabled
- }
- if !br.Config.PublicMedia.UseDatabase {
- if evt.File != nil {
- return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled)
- }
- return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil
- }
- mxc := evt.URL
- var keys *attachment.EncryptedFile
- if evt.File != nil {
- mxc = evt.File.URL
- keys = &evt.File.EncryptedFile
- }
- parsedMXC, err := mxc.Parse()
- if err != nil {
- return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
- }
- pm := &database.PublicMedia{
- MXC: parsedMXC,
- Keys: keys,
- MimeType: evt.GetInfo().MimeType,
- }
- if br.Config.PublicMedia.Expiry > 0 {
- pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second)
- }
- pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm))
- err = br.Bridge.DB.PublicMedia.Put(ctx, pm)
- if err != nil {
- return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
- }
- fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_"))
- if fileName == ".." {
- fileName = ""
- }
- parts := []string{
- br.GetPublicAddress(),
- strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
- pm.PublicID,
- fileName,
- }
- if fileName == "" {
- parts = parts[:len(parts)-1]
- }
- return strings.Join(parts, "/"), nil
+ )
}
diff --git a/bridgev2/matrix/websocket.go b/bridgev2/matrix/websocket.go
index b498cacd..c679f960 100644
--- a/bridgev2/matrix/websocket.go
+++ b/bridgev2/matrix/websocket.go
@@ -57,7 +57,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) {
addr = br.Config.Homeserver.Address
}
for {
- err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect)
+ err := br.AS.StartWebsocket(addr, onConnect)
if errors.Is(err, appservice.ErrWebsocketManualStop) {
return
} else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced {
diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go
index be26db49..ae1b99d7 100644
--- a/bridgev2/matrixinterface.go
+++ b/bridgev2/matrixinterface.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -10,11 +10,10 @@ import (
"context"
"fmt"
"io"
- "net/http"
"os"
"time"
- "go.mau.fi/util/exhttp"
+ "github.com/gorilla/mux"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
@@ -25,10 +24,8 @@ import (
)
type MatrixCapabilities struct {
- AutoJoinInvites bool
- BatchSending bool
- ArbitraryMemberChange bool
- ExtraProfileMeta bool
+ AutoJoinInvites bool
+ BatchSending bool
}
type MatrixConnector interface {
@@ -61,55 +58,32 @@ type MatrixConnector interface {
ServerName() string
}
-type MatrixConnectorWithArbitraryRoomState interface {
- MatrixConnector
- GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error)
-}
-
type MatrixConnectorWithServer interface {
- MatrixConnector
GetPublicAddress() string
- GetRouter() *http.ServeMux
-}
-
-type IProvisioningAPI interface {
- GetRouter() *http.ServeMux
- GetUser(r *http.Request) *User
-}
-
-type MatrixConnectorWithProvisioning interface {
- MatrixConnector
- GetProvisioning() IProvisioningAPI
+ GetRouter() *mux.Router
}
type MatrixConnectorWithPublicMedia interface {
- MatrixConnector
GetPublicMediaAddress(contentURI id.ContentURIString) string
- GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error)
}
type MatrixConnectorWithNameDisambiguation interface {
- MatrixConnector
IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error)
}
type MatrixConnectorWithBridgeIdentifier interface {
- MatrixConnector
GetUniqueBridgeID() string
}
type MatrixConnectorWithURLPreviews interface {
- MatrixConnector
GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error)
}
type MatrixConnectorWithPostRoomBridgeHandling interface {
- MatrixConnector
HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error
}
type MatrixConnectorWithAnalytics interface {
- MatrixConnector
TrackAnalytics(userID id.UserID, event string, properties map[string]any)
}
@@ -124,15 +98,9 @@ type DirectNotificationData struct {
}
type MatrixConnectorWithNotifications interface {
- MatrixConnector
DisplayNotification(ctx context.Context, data *DirectNotificationData)
}
-type MatrixConnectorWithHTTPSettings interface {
- MatrixConnector
- GetHTTPClientSettings() exhttp.ClientSettings
-}
-
type MatrixSendExtra struct {
Timestamp time.Time
MessageMeta *database.Message
@@ -176,10 +144,6 @@ func (ce CallbackError) Unwrap() error {
return ce.Wrapped
}
-type EnsureJoinedParams struct {
- Via []string
-}
-
type MatrixAPI interface {
GetMXID() id.UserID
IsDoublePuppet() bool
@@ -200,26 +164,17 @@ type MatrixAPI interface {
CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error)
DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error
- EnsureJoined(ctx context.Context, roomID id.RoomID, params ...EnsureJoinedParams) error
+ EnsureJoined(ctx context.Context, roomID id.RoomID) error
EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error
TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error
MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error
-
- GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error)
}
type StreamOrderReadingMatrixAPI interface {
- MatrixAPI
MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error
}
type MarkAsDMMatrixAPI interface {
- MatrixAPI
MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error
}
-
-type EphemeralSendingMatrixAPI interface {
- MatrixAPI
- BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error)
-}
diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go
index 75c00cb0..bfbabd26 100644
--- a/bridgev2/matrixinvite.go
+++ b/bridgev2/matrixinvite.go
@@ -88,36 +88,6 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI,
rejectInvite(ctx, evt, intent, "")
}
-func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) {
- if portal.MXID == "" {
- return
- }
- log := zerolog.Ctx(ctx)
- existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID)
- if err != nil {
- log.Err(err).
- Stringer("old_portal_mxid", portal.MXID).
- Msg("Failed to check existing portal members, deleting room")
- } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok {
- log.Debug().
- Stringer("old_portal_mxid", portal.MXID).
- Msg("Inviter has no member event in old portal, deleting room")
- } else if targetUserMember.Membership.IsInviteOrJoin() {
- return
- } else {
- log.Debug().
- Stringer("old_portal_mxid", portal.MXID).
- Str("membership", string(targetUserMember.Membership)).
- Msg("Inviter is not in old portal, deleting room")
- }
-
- if err = portal.RemoveMXID(ctx); err != nil {
- log.Err(err).Msg("Failed to delete old portal mxid")
- } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
- log.Err(err).Msg("Failed to clean up old portal room")
- }
-}
-
func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult {
ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey()))
validator, ok := br.Network.(IdentifierValidatingNetwork)
@@ -195,7 +165,34 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
return EventHandlingResultFailed
}
}
- portal.CleanupOrphanedDM(ctx, sender.MXID)
+ if portal.MXID != "" {
+ doCleanup := true
+ existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID)
+ if err != nil {
+ log.Err(err).
+ Stringer("old_portal_mxid", portal.MXID).
+ Msg("Failed to check existing portal members, deleting room")
+ } else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok {
+ log.Debug().
+ Stringer("old_portal_mxid", portal.MXID).
+ Msg("Inviter has no member event in old portal, deleting room")
+ } else if targetUserMember.Membership.IsInviteOrJoin() {
+ doCleanup = false
+ } else {
+ log.Debug().
+ Stringer("old_portal_mxid", portal.MXID).
+ Str("membership", string(targetUserMember.Membership)).
+ Msg("Inviter is not in old portal, deleting room")
+ }
+
+ if doCleanup {
+ if err = portal.RemoveMXID(ctx); err != nil {
+ log.Err(err).Msg("Failed to delete old portal mxid")
+ } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
+ log.Err(err).Msg("Failed to clean up old portal room")
+ }
+ }
+ }
err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID())
if err != nil {
log.Err(err).Msg("Failed to ensure bot is invited to room")
@@ -209,67 +206,72 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
return EventHandlingResultFailed
}
- portal.roomCreateLock.Lock()
- defer portal.roomCreateLock.Unlock()
- portalMXID := portal.MXID
- if portalMXID != "" {
- sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL())
- rejectInvite(ctx, evt, br.Bot, "")
- return EventHandlingResultSuccess
- }
- err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
- if err != nil {
- log.Err(err).Msg("Failed to give permissions to bridge bot")
- sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot")
- rejectInvite(ctx, evt, br.Bot, "")
- return EventHandlingResultSuccess
- }
- overrideIntent := invitedGhost.Intent
- if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
- log.Debug().
- Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
- Msg("Created DM was redirected to another user ID")
- _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
- Parsed: &event.MemberEventContent{
- Membership: event.MembershipLeave,
- Reason: "Direct chat redirected to another internal user ID",
+ didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID)
+ if didSetPortal {
+ message := "Private chat portal created"
+ err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
+ hasWarning := false
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to give power to bot in new DM")
+ message += "\n\nWarning: failed to promote bot"
+ hasWarning = true
+ }
+ if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
+ log.Debug().
+ Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
+ Msg("Created DM was redirected to another user ID")
+ _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
+ Parsed: &event.MemberEventContent{
+ Membership: event.MembershipLeave,
+ Reason: "Direct chat redirected to another internal user ID",
+ },
+ }, time.Time{})
+ if err != nil {
+ log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
+ }
+ otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo)
+ if err != nil {
+ log.Err(err).Msg("Failed to get ghost of real portal other user ID")
+ } else {
+ invitedGhost = otherUserGhost
+ }
+ }
+ if resp.PortalInfo != nil {
+ portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{})
+ } else {
+ portal.UpdateCapabilities(ctx, sourceLogin, true)
+ portal.UpdateBridgeInfo(ctx)
+ }
+ // TODO this might become unnecessary if UpdateInfo starts taking care of it
+ _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
+ Parsed: &event.ElementFunctionalMembersContent{
+ ServiceMembers: []id.UserID{br.Bot.GetMXID()},
},
}, time.Time{})
if err != nil {
- log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
+ log.Warn().Err(err).Msg("Failed to set service members in room")
+ if !hasWarning {
+ message += "\n\nWarning: failed to set service members"
+ hasWarning = true
+ }
}
- if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot {
- overrideIntent = br.Bot
- } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil {
- log.Err(err).Msg("Failed to get ghost of real portal other user ID")
- } else {
- invitedGhost = otherUserGhost
- overrideIntent = otherUserGhost.Intent
+ mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
+ if ok {
+ err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
+ if err != nil {
+ if hasWarning {
+ message += fmt.Sprintf(", %s", err.Error())
+ } else {
+ message += fmt.Sprintf("\n\nWarning: %s", err.Error())
+ }
+ }
}
+ sendNotice(ctx, evt, invitedGhost.Intent, message)
+ } else {
+ // TODO ensure user is invited even if PortalInfo wasn't provided?
+ sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL())
+ rejectInvite(ctx, evt, br.Bot, "")
}
- err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{
- // We locked it before checking the mxid
- RoomCreateAlreadyLocked: true,
-
- FailIfMXIDSet: true,
- ChatInfo: resp.PortalInfo,
- ChatInfoSource: sourceLogin,
- })
- if err != nil {
- log.Err(err).Msg("Failed to update Matrix room ID for new DM portal")
- sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work")
- return EventHandlingResultSuccess
- }
- message := "Private chat portal created"
- mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
- if ok {
- err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
- if err != nil {
- log.Err(err).Msg("Error in connector newly bridged room handler")
- message += fmt.Sprintf("\n\nWarning: %s", err.Error())
- }
- }
- sendNotice(ctx, evt, overrideIntent, message)
return EventHandlingResultSuccess
}
@@ -292,3 +294,21 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith
}
return nil
}
+
+func (portal *Portal) setMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
+ portal.roomCreateLock.Lock()
+ defer portal.roomCreateLock.Unlock()
+ if portal.MXID != "" {
+ return false
+ }
+ portal.MXID = roomID
+ portal.updateLogger()
+ portal.Bridge.cacheLock.Lock()
+ portal.Bridge.portalsByMXID[portal.MXID] = portal
+ portal.Bridge.cacheLock.Unlock()
+ err := portal.Save(ctx)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating mxid")
+ }
+ return true
+}
diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go
index df0c9e4d..7118649d 100644
--- a/bridgev2/messagestatus.go
+++ b/bridgev2/messagestatus.go
@@ -20,7 +20,6 @@ import (
type MessageStatusEventInfo struct {
RoomID id.RoomID
- TransactionID string
SourceEventID id.EventID
NewEventID id.EventID
EventType event.Type
@@ -42,7 +41,6 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo {
return &MessageStatusEventInfo{
RoomID: evt.RoomID,
- TransactionID: evt.Unsigned.TransactionID,
SourceEventID: evt.ID,
EventType: evt.Type,
MessageType: evt.Content.AsMessage().MsgType,
@@ -184,10 +182,9 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe
Type: event.RelReference,
EventID: evt.SourceEventID,
},
- TargetTxnID: evt.TransactionID,
- Status: ms.Status,
- Reason: ms.ErrorReason,
- Message: ms.Message,
+ Status: ms.Status,
+ Reason: ms.ErrorReason,
+ Message: ms.Message,
}
if ms.InternalError != nil {
content.InternalError = ms.InternalError.Error()
diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go
index e3a6df70..443d3655 100644
--- a/bridgev2/networkid/bridgeid.go
+++ b/bridgev2/networkid/bridgeid.go
@@ -47,8 +47,8 @@ type PortalID string
// As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true.
// The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user.
type PortalKey struct {
- ID PortalID `json:"portal_id"`
- Receiver UserLoginID `json:"portal_receiver,omitempty"`
+ ID PortalID
+ Receiver UserLoginID
}
func (pk PortalKey) IsEmpty() bool {
diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go
index b706aedb..eb38bd2d 100644
--- a/bridgev2/networkinterface.go
+++ b/bridgev2/networkinterface.go
@@ -16,9 +16,7 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/configupgrade"
"go.mau.fi/util/ptr"
- "go.mau.fi/util/random"
- "maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
@@ -119,15 +117,11 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa
mediaPart.Content.EnsureHasHTML()
mediaPart.Content.Body += "\n\n" + textPart.Content.Body
mediaPart.Content.FormattedBody += "
" + textPart.Content.FormattedBody
- mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions)
- mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...)
} else {
mediaPart.Content.FileName = mediaPart.Content.Body
mediaPart.Content.Body = textPart.Content.Body
mediaPart.Content.Format = textPart.Content.Format
mediaPart.Content.FormattedBody = textPart.Content.FormattedBody
- mediaPart.Content.Mentions = textPart.Content.Mentions
- mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews
}
if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok {
metaMerger.CopyFrom(textPart.DBMetadata)
@@ -261,7 +255,6 @@ type NetworkConnector interface {
}
type StoppableNetwork interface {
- NetworkConnector
// Stop is called when the bridge is stopping, after all network clients have been disconnected.
Stop()
}
@@ -318,16 +311,6 @@ type MaxFileSizeingNetwork interface {
SetMaxFileSize(maxSize int64)
}
-type NetworkResettingNetwork interface {
- NetworkConnector
- // ResetHTTPTransport should recreate the HTTP client used by the bridge.
- // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable.
- ResetHTTPTransport()
- // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections.
- // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here.
- ResetNetworkConnections()
-}
-
type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
type MatrixMessageResponse struct {
@@ -359,16 +342,10 @@ type NetworkGeneralCapabilities struct {
// Should the bridge re-request user info on incoming messages even if the ghost already has info?
// By default, info is only requested for ghosts with no name, and other updating is left to events.
AggressiveUpdateInfo bool
- // Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message?
- // This should be enabled if the network requires each message to be marked as read independently,
- // and doesn't automatically do it when sending a message.
- ImplicitReadReceipts bool
// If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave])
// to handle asynchronous message responses, this field can be set to enable
// automatic timeout errors in case the asynchronous response never arrives.
OutgoingMessageTimeouts *OutgoingTimeoutConfig
- // Capabilities related to the provisioning API.
- Provisioning ProvisioningCapabilities
}
// NetworkAPI is an interface representing a remote network client for a single user login.
@@ -702,35 +679,6 @@ type RoomTopicHandlingNetworkAPI interface {
HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error)
}
-type DisappearTimerChangingNetworkAPI interface {
- NetworkAPI
- // HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed.
- // This method should update the Disappear field of the Portal with the new timer and return true
- // if the change was successful. If the change is not successful, then the field should not be updated.
- HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error)
-}
-
-// DeleteChatHandlingNetworkAPI is an optional interface that network connectors
-// can implement to delete a chat from the remote network.
-type DeleteChatHandlingNetworkAPI interface {
- NetworkAPI
- // HandleMatrixDeleteChat is called when the user explicitly deletes a chat.
- HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error
-}
-
-// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors
-// can implement to accept message requests from the remote network.
-type MessageRequestAcceptingNetworkAPI interface {
- NetworkAPI
- // HandleMatrixAcceptMessageRequest is called when the user accepts a message request.
- HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error
-}
-
-type BeeperAIStreamHandlingNetworkAPI interface {
- NetworkAPI
- HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error
-}
-
type ResolveIdentifierResponse struct {
// Ghost is the ghost of the user that the identifier resolves to.
// This field should be set whenever possible. However, it is not required,
@@ -750,8 +698,6 @@ type ResolveIdentifierResponse struct {
Chat *CreateChatResponse
}
-var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10))
-
type CreateChatResponse struct {
PortalKey networkid.PortalKey
// Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary.
@@ -760,17 +706,6 @@ type CreateChatResponse struct {
// If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user,
// this field should have the user ID of said different user.
DMRedirectedTo networkid.UserID
-
- FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant
-}
-
-type CreateChatFailedParticipant struct {
- Reason string `json:"reason"`
- InviteEventType string `json:"invite_event_type,omitempty"`
- InviteContent *event.Content `json:"invite_content,omitempty"`
-
- UserMXID id.UserID `json:"user_mxid,omitempty"`
- DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"`
}
// IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats.
@@ -805,83 +740,7 @@ type UserSearchingNetworkAPI interface {
type GroupCreatingNetworkAPI interface {
IdentifierResolvingNetworkAPI
- CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error)
-}
-
-type PersonalFilteringCustomizingNetworkAPI interface {
- NetworkAPI
- CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom)
-}
-
-type ProvisioningCapabilities struct {
- ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"`
- GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"`
-}
-
-type ResolveIdentifierCapabilities struct {
- // Can DMs be created after resolving an identifier?
- CreateDM bool `json:"create_dm"`
- // Can users be looked up by phone number?
- LookupPhone bool `json:"lookup_phone"`
- // Can users be looked up by email address?
- LookupEmail bool `json:"lookup_email"`
- // Can users be looked up by network-specific username?
- LookupUsername bool `json:"lookup_username"`
- // Can any phone number be contacted without having to validate it via lookup first?
- AnyPhone bool `json:"any_phone"`
- // Can a contact list be retrieved from the bridge?
- ContactList bool `json:"contact_list"`
- // Can users be searched by name on the remote network?
- Search bool `json:"search"`
-}
-
-type GroupTypeCapabilities struct {
- TypeDescription string `json:"type_description"`
-
- Name GroupFieldCapability `json:"name"`
- Username GroupFieldCapability `json:"username"`
- Avatar GroupFieldCapability `json:"avatar"`
- Topic GroupFieldCapability `json:"topic"`
- Disappear GroupFieldCapability `json:"disappear"`
- Participants GroupFieldCapability `json:"participants"`
- Parent GroupFieldCapability `json:"parent"`
-}
-
-type GroupFieldCapability struct {
- // Is setting this field allowed at all in the create request?
- // Even if false, the network connector should attempt to set the metadata after group creation,
- // as the allowed flag can't be enforced properly when creating a group for an existing Matrix room.
- Allowed bool `json:"allowed"`
- // Is setting this field mandatory for the creation to succeed?
- Required bool `json:"required,omitempty"`
- // The minimum/maximum length of the field, if applicable.
- // For members, length means the number of members excluding the creator.
- MinLength int `json:"min_length,omitempty"`
- MaxLength int `json:"max_length,omitempty"`
-
- // Only for the disappear field: allowed disappearing settings
- DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"`
-
- // This can be used to tell provisionutil not to call ValidateUserID on each participant.
- // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs.
- SkipIdentifierValidation bool `json:"-"`
-}
-
-type GroupCreateParams struct {
- Type string `json:"type,omitempty"`
-
- Username string `json:"username,omitempty"`
- // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs
- Participants []networkid.UserID `json:"participants,omitempty"`
- Parent *networkid.PortalKey `json:"parent,omitempty"`
-
- Name *event.RoomNameEventContent `json:"name,omitempty"`
- Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"`
- Topic *event.TopicEventContent `json:"topic,omitempty"`
- Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"`
-
- // An existing room ID to bridge to. If unset, a new room will be created.
- RoomID id.RoomID `json:"room_id,omitempty"`
+ CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error)
}
type MembershipChangeType struct {
@@ -921,15 +780,16 @@ type MatrixMembershipChange struct {
MatrixRoomMeta[*event.MemberEventContent]
Target GhostOrUserLogin
Type MembershipChangeType
-}
-type MatrixMembershipResult struct {
- RedirectTo networkid.UserID
+ // Deprecated: Use Target instead
+ TargetGhost *Ghost
+ // Deprecated: Use Target instead
+ TargetUserLogin *UserLogin
}
type MembershipHandlingNetworkAPI interface {
NetworkAPI
- HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error)
+ HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error)
}
type SinglePowerLevelChange struct {
@@ -1115,11 +975,6 @@ type RemoteEvent interface {
GetSender() EventSender
}
-type RemoteEventWithContextMutation interface {
- RemoteEvent
- MutateContext(ctx context.Context) context.Context
-}
-
type RemoteEventWithUncertainPortalReceiver interface {
RemoteEvent
PortalReceiverIsUncertain() bool
@@ -1173,11 +1028,6 @@ type RemoteChatDelete interface {
RemoteDeleteOnlyForMe
}
-type RemoteChatDeleteWithChildren interface {
- RemoteChatDelete
- DeleteChildren() bool
-}
-
type RemoteEventThatMayCreatePortal interface {
RemoteEvent
ShouldCreatePortal() bool
@@ -1410,14 +1260,12 @@ type MatrixMessageRemove struct {
type MatrixRoomMeta[ContentType any] struct {
MatrixEventBase[ContentType]
- PrevContent ContentType
- IsStateRequest bool
+ PrevContent ContentType
}
type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent]
type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent]
type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent]
-type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer]
type MatrixReadReceipt struct {
Portal *Portal
@@ -1432,8 +1280,6 @@ type MatrixReadReceipt struct {
LastRead time.Time
// The receipt metadata.
Receipt event.ReadReceipt
- // Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt.
- Implicit bool
}
type MatrixTyping struct {
@@ -1447,9 +1293,6 @@ type MatrixViewingChat struct {
Portal *Portal
}
-type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent]
-type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent]
-type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent]
type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent]
type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent]
type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent]
diff --git a/bridgev2/portal.go b/bridgev2/portal.go
index 5ba29507..ab1f37f1 100644
--- a/bridgev2/portal.go
+++ b/bridgev2/portal.go
@@ -19,7 +19,6 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/exfmt"
- "go.mau.fi/util/exmaps"
"go.mau.fi/util/exslices"
"go.mau.fi/util/exsync"
"go.mau.fi/util/ptr"
@@ -86,15 +85,9 @@ type Portal struct {
lastCapUpdate time.Time
- roomCreateLock sync.Mutex
- cancelRoomCreate atomic.Pointer[context.CancelFunc]
- RoomCreated *exsync.Event
+ roomCreateLock sync.Mutex
- functionalMembersLock sync.Mutex
- functionalMembersCache *event.ElementFunctionalMembersContent
-
- events chan portalEvent
- deleted *exsync.Event
+ events chan portalEvent
eventsLock sync.Mutex
eventIdx int
@@ -126,15 +119,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
currentlyTypingLogins: make(map[id.UserID]*UserLogin),
currentlyTypingGhosts: exsync.NewSet[id.UserID](),
outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage),
-
- RoomCreated: exsync.NewEvent(),
- deleted: exsync.NewEvent(),
}
- if portal.MXID != "" {
- portal.RoomCreated.Set()
- }
- // Putting the portal in the cache before it's fully initialized is mildly dangerous,
- // but loading the relay user login may depend on it.
br.portalsByKey[portal.PortalKey] = portal
if portal.MXID != "" {
br.portalsByMXID[portal.MXID] = portal
@@ -143,20 +128,12 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
if portal.ParentKey.ID != "" {
portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false)
if err != nil {
- delete(br.portalsByKey, portal.PortalKey)
- if portal.MXID != "" {
- delete(br.portalsByMXID, portal.MXID)
- }
return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err)
}
}
if portal.RelayLoginID != "" {
portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID)
if err != nil {
- delete(br.portalsByKey, portal.PortalKey)
- if portal.MXID != "" {
- delete(br.portalsByMXID, portal.MXID)
- }
return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err)
}
}
@@ -169,9 +146,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
}
func (portal *Portal) updateLogger() {
- logWith := portal.Bridge.Log.With().
- Str("portal_id", string(portal.ID)).
- Str("portal_receiver", string(portal.Receiver))
+ logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID))
if portal.MXID != "" {
logWith = logWith.Stringer("portal_mxid", portal.MXID)
}
@@ -195,16 +170,6 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta
return output, nil
}
-func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) {
- if dbPortal == nil {
- return nil, nil
- } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok {
- return cached, nil
- } else {
- return br.loadPortal(ctx, dbPortal, nil, nil)
- }
-}
-
func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) {
if br.Config.SplitPortals && key.Receiver == "" {
return nil, fmt.Errorf("receiver must always be set when split portals is enabled")
@@ -294,26 +259,6 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us
return br.loadManyPortals(ctx, rows)
}
-func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) {
- br.cacheLock.Lock()
- defer br.cacheLock.Unlock()
- rows, err := br.DB.Portal.GetChildren(ctx, parent)
- if err != nil {
- return nil, err
- }
- return br.loadManyPortals(ctx, rows)
-}
-
-func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) {
- br.cacheLock.Lock()
- defer br.cacheLock.Unlock()
- dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID)
- if err != nil {
- return nil, err
- }
- return br.loadPortalWithCacheCheck(ctx, dbPortal)
-}
-
func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
@@ -339,23 +284,15 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port
}
func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult {
- if portal.deleted.IsSet() {
- return EventHandlingResultIgnored
- }
if PortalEventBuffer == 0 {
portal.eventsLock.Lock()
defer portal.eventsLock.Unlock()
portal.eventIdx++
- return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt)
+ return portal.handleSingleEventAsync(portal.eventIdx, evt)
} else {
- if portal.events == nil {
- panic(fmt.Errorf("queueEvent into uninitialized portal %s", portal.PortalKey))
- }
select {
case portal.events <- evt:
return EventHandlingResultQueued
- case <-portal.deleted.GetChan():
- return EventHandlingResultIgnored
default:
zerolog.Ctx(ctx).Error().
Str("portal_id", string(portal.ID)).
@@ -380,88 +317,64 @@ func (portal *Portal) eventLoop() {
go portal.pendingMessageTimeoutLoop(ctx, cfg)
defer cancel()
}
- deleteCh := portal.deleted.GetChan()
- for i := 0; ; i++ {
- select {
- case rawEvt := <-portal.events:
- if rawEvt == nil {
- return
- }
- if portal.Bridge.Config.AsyncEvents {
- go portal.handleSingleEventWithDelayLogging(i, rawEvt)
- } else {
- portal.handleSingleEventWithDelayLogging(i, rawEvt)
- }
- case <-deleteCh:
- return
- }
+ i := 0
+ for rawEvt := range portal.events {
+ i++
+ portal.handleSingleEventAsync(i, rawEvt)
}
}
-func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) {
+func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) {
ctx := portal.getEventCtxWithLog(rawEvt, idx)
- log := zerolog.Ctx(ctx)
- doneCh := make(chan struct{})
- var backgrounded atomic.Bool
- start := time.Now()
- var handleDuration time.Duration
- // Note: this will not set the success flag if the handler times out
- outerRes = EventHandlingResult{Queued: true}
- go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {
- outerRes = res
- handleDuration = time.Since(start)
- close(doneCh)
- if backgrounded.Load() {
- log.Debug().
- Time("started_at", start).
- Stringer("duration", handleDuration).
- Msg("Event that took too long finally finished handling")
- }
- })
- tick := time.NewTicker(30 * time.Second)
- _, isCreate := rawEvt.(*portalCreateEvent)
- defer tick.Stop()
- for i := 0; i < 10; i++ {
- select {
- case <-doneCh:
- if i > 0 {
+ if _, isCreate := rawEvt.(*portalCreateEvent); isCreate {
+ portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {
+ outerRes = res
+ })
+ } else if portal.Bridge.Config.AsyncEvents {
+ outerRes = EventHandlingResultQueued
+ go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {})
+ } else {
+ log := zerolog.Ctx(ctx)
+ doneCh := make(chan struct{})
+ var backgrounded atomic.Bool
+ start := time.Now()
+ var handleDuration time.Duration
+ // Note: this will not set the success flag if the handler times out
+ outerRes = EventHandlingResult{Queued: true}
+ go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {
+ outerRes = res
+ handleDuration = time.Since(start)
+ close(doneCh)
+ if backgrounded.Load() {
log.Debug().
Time("started_at", start).
Stringer("duration", handleDuration).
- Msg("Event that took long finished handling")
+ Msg("Event that took too long finally finished handling")
}
- return
- case <-tick.C:
- log.Warn().
- Time("started_at", start).
- Msg("Event handling is taking long")
- if isCreate {
- // Never background portal creation events
- i = 1
+ })
+ tick := time.NewTicker(30 * time.Second)
+ defer tick.Stop()
+ for i := 0; i < 10; i++ {
+ select {
+ case <-doneCh:
+ if i > 0 {
+ log.Debug().
+ Time("started_at", start).
+ Stringer("duration", handleDuration).
+ Msg("Event that took long finished handling")
+ }
+ return
+ case <-tick.C:
+ log.Warn().
+ Time("started_at", start).
+ Msg("Event handling is taking long")
}
}
+ log.Warn().
+ Time("started_at", start).
+ Msg("Event handling is taking too long, continuing in background")
+ backgrounded.Store(true)
}
- log.Warn().
- Time("started_at", start).
- Msg("Event handling is taking too long, continuing in background")
- backgrounded.Store(true)
- return
-}
-
-type contextKey int
-
-const (
- contextKeyRemoteEvent contextKey = iota
- contextKeyMatrixEvent
-)
-
-func GetMatrixEventFromContext(ctx context.Context) (evt *event.Event) {
- evt, _ = ctx.Value(contextKeyMatrixEvent).(*event.Event)
- return
-}
-
-func GetRemoteEventFromContext(ctx context.Context) (evt RemoteEvent) {
- evt, _ = ctx.Value(contextKeyRemoteEvent).(RemoteEvent)
return
}
@@ -478,10 +391,6 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context {
Stringer("event_id", evt.evt.ID).
Stringer("sender", evt.sender.MXID)
}
- ctx := portal.Bridge.BackgroundCtx
- ctx = context.WithValue(ctx, contextKeyMatrixEvent, evt.evt)
- ctx = logWith.Logger().WithContext(ctx)
- return ctx
case *portalRemoteEvent:
evt.evtType = evt.evt.GetType()
logWith = portal.Log.With().Int("event_loop_index", idx).
@@ -507,23 +416,10 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context {
logWith = logWith.Int64("remote_stream_order", remoteStreamOrder)
}
}
- if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok {
- if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() {
- logWith = logWith.Time("remote_timestamp", remoteTimestamp)
- }
- }
- ctx := portal.Bridge.BackgroundCtx
- ctx = context.WithValue(ctx, contextKeyRemoteEvent, evt.evt)
- ctx = logWith.Logger().WithContext(ctx)
- if ctxMut, ok := evt.evt.(RemoteEventWithContextMutation); ok {
- ctx = ctxMut.MutateContext(ctx)
- }
- return ctx
case *portalCreateEvent:
return evt.ctx
- default:
- panic(fmt.Errorf("invalid type %T in getEventCtxWithLog", evt))
}
+ return logWith.Logger().WithContext(portal.Bridge.BackgroundCtx)
}
func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(res EventHandlingResult)) {
@@ -559,14 +455,7 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
}()
switch evt := rawEvt.(type) {
case *portalMatrixEvent:
- isStateRequest := evt.evt.Type == event.BeeperSendState
- if isStateRequest {
- if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil {
- portal.sendErrorStatus(ctx, evt.evt, err)
- return
- }
- }
- res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest)
+ res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt)
if res.SendMSS {
if res.Error != nil {
portal.sendErrorStatus(ctx, evt.evt, res.Error)
@@ -574,21 +463,6 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
portal.sendSuccessStatus(ctx, evt.evt, 0, "")
}
}
- if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil {
- portal.revertRoomMeta(ctx, evt.evt)
- }
- if isStateRequest && res.Success && !res.SkipStateEcho {
- portal.sendRoomMeta(
- ctx,
- evt.sender.DoublePuppet(ctx),
- time.UnixMilli(evt.evt.Timestamp),
- evt.evt.Type,
- evt.evt.GetStateKey(),
- evt.evt.Content.Parsed,
- false,
- evt.evt.Content.Raw,
- )
- }
case *portalRemoteEvent:
res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt)
case *portalCreateEvent:
@@ -600,44 +474,18 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
}
}
-func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
- content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent)
- if !ok {
- return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)
- }
- evt.Content = content.Content
- evt.StateKey = &content.StateKey
- evt.Type = event.Type{Type: content.Type, Class: event.StateEventType}
- _ = evt.Content.ParseRaw(evt.Type)
- mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
- if !ok {
- return fmt.Errorf("matrix connector doesn't support fetching state")
- }
- prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey())
- if err != nil && !errors.Is(err, mautrix.MNotFound) {
- return fmt.Errorf("failed to get prev event: %w", err)
- } else if prevEvt != nil {
- evt.Unsigned.PrevContent = &prevEvt.Content
- evt.Unsigned.PrevSender = prevEvt.Sender
- }
- return nil
-}
-
func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) {
if portal.Receiver != "" {
login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver)
if err != nil {
return nil, nil, err
}
- if login == nil {
- return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn)
- } else if !login.Client.IsLoggedIn() {
- return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn)
- } else if login.UserMXID != user.MXID {
+ if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() {
if allowRelay && portal.Relay != nil {
return nil, nil, nil
}
- return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID)
+ // TODO different error for this case?
+ return nil, nil, ErrNotLoggedIn
}
up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey)
return login, up, err
@@ -720,7 +568,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID,
var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"}
-func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
+func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
log := zerolog.Ctx(ctx)
if evt.Mautrix.EventSource&event.SourceEphemeral != 0 {
switch evt.Type {
@@ -728,17 +576,11 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
return portal.handleMatrixReceipts(ctx, evt)
case event.EphemeralEventTyping:
return portal.handleMatrixTyping(ctx, evt)
- case event.BeeperEphemeralEventAIStream:
- return portal.handleMatrixAIStream(ctx, sender, evt)
default:
return EventHandlingResultIgnored
}
}
- if evt.Type == event.StateTombstone {
- // Tombstones aren't bridged so they don't need a login
- return portal.handleMatrixTombstone(ctx, evt)
- }
- login, userPortal, err := portal.FindPreferredLogin(ctx, sender, true)
+ login, _, err := portal.FindPreferredLogin(ctx, sender, true)
if err != nil {
log.Err(err).Msg("Failed to get user login to handle Matrix event")
if errors.Is(err, ErrNotLoggedIn) {
@@ -754,9 +596,6 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
}
var origSender *OrigSender
if login == nil {
- if isStateRequest {
- return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest)
- }
login = portal.Relay
origSender = &OrigSender{
User: sender,
@@ -800,21 +639,6 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
}
// Copy logger because many of the handlers will use UpdateContext
ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx)
-
- if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts && !evt.Type.IsAccountData() {
- rrLog := log.With().Str("subaction", "implicit read receipt").Logger()
- rrCtx := rrLog.WithContext(ctx)
- rrLog.Debug().Msg("Sending implicit read receipt for event")
- evtTS := time.UnixMilli(evt.Timestamp)
- portal.callReadReceiptHandler(rrCtx, login, nil, &MatrixReadReceipt{
- Portal: portal,
- EventID: evt.ID,
- Implicit: true,
- ReadUpTo: evtTS,
- Receipt: event.ReadReceipt{Timestamp: evtTS},
- }, userPortal)
- }
-
switch evt.Type {
case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse:
return portal.handleMatrixMessage(ctx, login, origSender, evt)
@@ -827,13 +651,11 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
case event.EventRedaction:
return portal.handleMatrixRedaction(ctx, login, origSender, evt)
case event.StateRoomName:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName)
case event.StateTopic:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic)
case event.StateRoomAvatar:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar)
- case event.StateBeeperDisappearingTimer:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar)
case event.StateEncryption:
// TODO?
return EventHandlingResultIgnored
@@ -844,13 +666,9 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
case event.AccountDataBeeperMute:
return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute)
case event.StateMember:
- return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest)
+ return portal.handleMatrixMembership(ctx, login, origSender, evt)
case event.StatePowerLevels:
- return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest)
- case event.BeeperDeleteChat:
- return portal.handleMatrixDeleteChat(ctx, login, origSender, evt)
- case event.BeeperAcceptMessageRequest:
- return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt)
+ return portal.handleMatrixPowerLevels(ctx, login, origSender, evt)
default:
return EventHandlingResultIgnored
}
@@ -870,7 +688,7 @@ func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event
sender, err := portal.Bridge.GetUserByMXID(ctx, userID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt)
}
@@ -908,10 +726,15 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e
EventID: eventID,
Receipt: receipt,
}
+ if userPortal == nil {
+ userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey)
+ } else {
+ evt.LastRead = userPortal.LastRead
+ userPortal = userPortal.CopyWithoutValues()
+ }
evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID)
if err != nil {
log.Err(err).Msg("Failed to get exact message from database")
- evt.ReadUpTo = receipt.Timestamp
} else if evt.ExactMessage != nil {
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp)
@@ -920,40 +743,21 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e
} else {
evt.ReadUpTo = receipt.Timestamp
}
- portal.callReadReceiptHandler(ctx, login, rrClient, evt, userPortal)
-}
-
-func (portal *Portal) callReadReceiptHandler(
- ctx context.Context,
- login *UserLogin,
- rrClient ReadReceiptHandlingNetworkAPI,
- evt *MatrixReadReceipt,
- userPortal *database.UserPortal,
-) {
- if rrClient == nil {
- var ok bool
- rrClient, ok = login.Client.(ReadReceiptHandlingNetworkAPI)
- if !ok {
- return
- }
- }
- if userPortal == nil {
- userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey)
- } else {
- evt.LastRead = userPortal.LastRead
- userPortal = userPortal.CopyWithoutValues()
- }
- err := rrClient.HandleMatrixReadReceipt(ctx, evt)
+ err = rrClient.HandleMatrixReadReceipt(ctx, evt)
if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to handle read receipt")
+ log.Err(err).Msg("Failed to handle read receipt")
return
}
- userPortal.LastRead = evt.ReadUpTo
+ if evt.ExactMessage != nil {
+ userPortal.LastRead = evt.ExactMessage.Timestamp
+ } else {
+ userPortal.LastRead = receipt.Timestamp
+ }
err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal)
if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata")
+ log.Err(err).Msg("Failed to save user portal metadata")
}
- portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo)
+ portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID)
}
func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult {
@@ -974,50 +778,6 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event)
return EventHandlingResultSuccess
}
-func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
- log := zerolog.Ctx(ctx)
- if sender == nil {
- log.Error().Msg("Missing sender for Matrix AI stream event")
- return EventHandlingResultIgnored
- }
- login, _, err := portal.FindPreferredLogin(ctx, sender, true)
- if err != nil {
- log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event")
- return EventHandlingResultFailed.WithMSSError(err)
- }
- var origSender *OrigSender
- if login == nil {
- if portal.Relay == nil {
- return EventHandlingResultIgnored
- }
- login = portal.Relay
- origSender = &OrigSender{
- User: sender,
- UserID: sender.MXID,
- }
- }
- content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent)
- if !ok {
- log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
- return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
- }
- api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI)
- if !ok {
- return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported)
- }
- err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{
- Event: evt,
- Content: content,
- Portal: portal,
- OrigSender: origSender,
- })
- if err != nil {
- log.Err(err).Msg("Failed to handle Matrix AI stream event")
- return EventHandlingResultFailed.WithMSSError(err)
- }
- return EventHandlingResultSuccess.WithMSS()
-}
-
func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) {
for _, userID := range userIDs {
login, ok := portal.currentlyTypingLogins[userID]
@@ -1126,18 +886,8 @@ func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content
feat.Caption.Reject() {
return ErrCaptionsNotAllowed
}
- if content.Info != nil {
- dur := time.Duration(content.Info.Duration) * time.Millisecond
- if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration {
- if capMsgType == event.CapMsgVoice {
- return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration))
- }
- return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration))
- }
- if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize {
- return fmt.Errorf("%w: %.1f MiB is larger than the maximum of %.1f MiB", ErrMediaTooLarge, float64(content.Info.Size)/1024/1024, float64(feat.MaxSize)/1024/1024)
- }
- if content.Info.MimeType != "" && feat.GetMimeSupport(content.Info.MimeType).Reject() {
+ if content.Info != nil && content.Info.MimeType != "" {
+ if feat.GetMimeSupport(content.Info.MimeType).Reject() {
return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType)
}
}
@@ -1197,12 +947,10 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
log.Debug().Msg("Ignoring poll event from relayed user")
return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser)
}
- if !caps.PerMessageProfileRelay {
- msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender)
- if err != nil {
- log.Err(err).Msg("Failed to format message for relaying")
- return EventHandlingResultFailed.WithMSSError(err)
- }
+ msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender)
+ if err != nil {
+ log.Err(err).Msg("Failed to format message for relaying")
+ return EventHandlingResultFailed.WithMSSError(err)
}
}
if msgContent != nil {
@@ -1270,16 +1018,6 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
}
- var messageTimer *event.BeeperDisappearingTimer
- if msgContent != nil {
- messageTimer = msgContent.BeeperDisappearingTimer
- }
- if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer {
- log.Warn().
- Any("event_timer", messageTimer).
- Any("portal_timer", portal.Disappear.ToEventContent()).
- Msg("Mismatching disappearing timer in event")
- }
wrappedMsgEvt := &MatrixMessage{
MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{
@@ -1305,12 +1043,6 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
- err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps)
- if err != nil {
- log.Warn().Err(err).Msg("Failed to auto-accept message request on message")
- // TODO stop processing?
- }
-
var resp *MatrixMessageResponse
if msgContent != nil {
resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt)
@@ -1362,23 +1094,22 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID)
}
- ds := portal.Disappear
- if messageTimer != nil {
- ds = database.DisappearingSettingFromEvent(messageTimer)
- }
- if ds.Type != event.DisappearingTypeNone {
+ if portal.Disappear.Type != database.DisappearingTypeNone {
go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{
- RoomID: portal.MXID,
- EventID: message.MXID,
- Timestamp: message.Timestamp,
- DisappearingSetting: ds.StartingAt(message.Timestamp),
+ RoomID: portal.MXID,
+ EventID: message.MXID,
+ DisappearingSetting: database.DisappearingSetting{
+ Type: portal.Disappear.Type,
+ Timer: portal.Disappear.Timer,
+ DisappearAt: message.Timestamp.Add(portal.Disappear.Timer),
+ },
})
}
if resp.Pending {
// Not exactly queued, but not finished either
return EventHandlingResultQueued
}
- return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder)
+ return EventHandlingResultSuccess
}
// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message.
@@ -1567,7 +1298,7 @@ func (portal *Portal) handleMatrixEdit(
return EventHandlingResultSuccess
}
-func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) {
+func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult {
log := zerolog.Ctx(ctx)
reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI)
if !ok {
@@ -1590,12 +1321,6 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
log.Warn().Msg("Reaction target message not found in database")
return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound))
}
- caps := sender.Client.GetCapabilities(ctx, portal)
- err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps)
- if err != nil {
- log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction")
- // TODO stop processing?
- }
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("reaction_target_remote_id", string(reactionTarget.ID))
})
@@ -1618,31 +1343,6 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if portal.Bridge.Config.OutgoingMessageReID {
deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID)
}
- defer func() {
- // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction
- if handleRes.Success {
- portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
- }
- }()
- removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) {
- if !handleRes.Success {
- return
- }
- _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
- Parsed: &event.RedactionEventContent{
- Redacts: oldReact.MXID,
- },
- }, nil)
- if err != nil {
- log.Err(err).Msg("Failed to remove old reaction")
- }
- if deleteDB {
- err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact)
- if err != nil {
- log.Err(err).Msg("Failed to delete old reaction from database")
- }
- }
- }
existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
@@ -1651,10 +1351,17 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if existing.EmojiID != "" || existing.Emoji == preResp.Emoji {
log.Debug().Msg("Ignoring duplicate reaction")
portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
- return EventHandlingResultIgnored.WithEventID(deterministicID)
+ return EventHandlingResultIgnored
}
react.ReactionToOverride = existing
- defer removeOutdatedReaction(existing, false)
+ _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
+ Parsed: &event.RedactionEventContent{
+ Redacts: existing.MXID,
+ },
+ }, nil)
+ if err != nil {
+ log.Err(err).Msg("Failed to remove old reaction")
+ }
}
react.PreHandleResp = &preResp
if preResp.MaxReactions > 0 {
@@ -1669,14 +1376,18 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
// Keep n-1 previous reactions and remove the rest
react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1]
for _, oldReaction := range allReactions[preResp.MaxReactions-1:] {
- if existing != nil && oldReaction.EmojiID == existing.EmojiID {
- // Don't double-delete on networks that only allow one emoji
- continue
+ _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
+ Parsed: &event.RedactionEventContent{
+ Redacts: oldReaction.MXID,
+ },
+ }, nil)
+ if err != nil {
+ log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded")
+ }
+ err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction)
+ if err != nil {
+ log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded")
}
- // Intentionally defer in a loop, there won't be that many items,
- // and we want all of them to be done after this function completes successfully
- //goland:noinspection GoDeferInLoop
- defer removeOutdatedReaction(oldReaction, true)
}
}
}
@@ -1721,7 +1432,8 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if err != nil {
log.Err(err).Msg("Failed to save reaction to database")
}
- return EventHandlingResultSuccess.WithEventID(deterministicID)
+ portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
+ return EventHandlingResultSuccess
}
func handleMatrixRoomMeta[APIType any, ContentType any](
@@ -1730,19 +1442,11 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
- isStateRequest bool,
fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error),
) EventHandlingResult {
- if evt.StateKey == nil || *evt.StateKey != "" {
- return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
- }
- //caps := sender.Client.GetCapabilities(ctx, portal)
- //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported {
- // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed))
- //}
api, ok := sender.Client.(APIType)
if !ok {
- return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type))
+ return EventHandlingResultIgnored.WithMSSError(ErrRoomMetadataNotSupported)
}
log := zerolog.Ctx(ctx)
content, ok := evt.Content.Parsed.(ContentType)
@@ -1766,18 +1470,6 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
portal.sendSuccessStatus(ctx, evt, 0, "")
return EventHandlingResultIgnored
}
- case *event.BeeperDisappearingTimer:
- if typedContent.Type == event.DisappearingTypeNone || typedContent.Timer.Duration <= 0 {
- typedContent.Type = event.DisappearingTypeNone
- typedContent.Timer.Duration = 0
- }
- if typedContent.Type == portal.Disappear.Type && typedContent.Timer.Duration == portal.Disappear.Timer {
- portal.sendSuccessStatus(ctx, evt, 0, "")
- return EventHandlingResultIgnored
- }
- if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) {
- return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported)
- }
}
var prevContent ContentType
if evt.Unsigned.PrevContent != nil {
@@ -1794,17 +1486,14 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- IsStateRequest: isStateRequest,
- PrevContent: prevContent,
+ PrevContent: prevContent,
})
if err != nil {
log.Err(err).Msg("Failed to handle Matrix room metadata")
return EventHandlingResultFailed.WithMSSError(err)
}
if changed {
- if evt.Type != event.StateBeeperDisappearingTimer {
- portal.UpdateBridgeInfo(ctx)
- }
+ portal.UpdateBridgeInfo(ctx)
err = portal.Save(ctx)
if err != nil {
log.Err(err).Msg("Failed to save portal after updating room metadata")
@@ -1865,139 +1554,12 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos
}
}
-func (portal *Portal) handleMatrixAcceptMessageRequest(
- ctx context.Context,
- sender *UserLogin,
- origSender *OrigSender,
- evt *event.Event,
-) EventHandlingResult {
- if origSender != nil {
- return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser)
- }
- log := zerolog.Ctx(ctx)
- content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent)
- if !ok {
- log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
- return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
- }
- api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI)
- if !ok {
- return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported)
- }
- err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{
- Event: evt,
- Content: content,
- Portal: portal,
- })
- if err != nil {
- log.Err(err).Msg("Failed to handle Matrix accept message request")
- return EventHandlingResultFailed.WithMSSError(err)
- }
- if portal.MessageRequest {
- portal.MessageRequest = false
- portal.UpdateBridgeInfo(ctx)
- err = portal.Save(ctx)
- if err != nil {
- log.Err(err).Msg("Failed to save portal after accepting message request")
- }
- }
- return EventHandlingResultSuccess.WithMSS()
-}
-
-func (portal *Portal) autoAcceptMessageRequest(
- ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures,
-) error {
- if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported {
- return nil
- }
- mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI)
- if !ok {
- return nil
- }
- err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{
- Event: evt,
- Content: &event.BeeperAcceptMessageRequestEventContent{
- IsImplicit: true,
- },
- Portal: portal,
- OrigSender: origSender,
- })
- if err != nil {
- return err
- }
- if portal.MessageRequest {
- portal.MessageRequest = false
- portal.UpdateBridgeInfo(ctx)
- err = portal.Save(ctx)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request")
- }
- }
- return nil
-}
-
-func (portal *Portal) handleMatrixDeleteChat(
- ctx context.Context,
- sender *UserLogin,
- origSender *OrigSender,
- evt *event.Event,
-) EventHandlingResult {
- if origSender != nil {
- return EventHandlingResultFailed.WithMSSError(ErrIgnoringDeleteChatRelayedUser)
- }
- log := zerolog.Ctx(ctx)
- content, ok := evt.Content.Parsed.(*event.BeeperChatDeleteEventContent)
- if !ok {
- log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
- return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
- }
- api, ok := sender.Client.(DeleteChatHandlingNetworkAPI)
- if !ok {
- return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported)
- }
- err := api.HandleMatrixDeleteChat(ctx, &MatrixDeleteChat{
- Event: evt,
- Content: content,
- Portal: portal,
- })
- if err != nil {
- log.Err(err).Msg("Failed to handle Matrix chat delete")
- return EventHandlingResultFailed.WithMSSError(err)
- }
- if portal.Receiver == "" {
- _, others, err := portal.findOtherLogins(ctx, sender)
- if err != nil {
- log.Err(err).Msg("Failed to check if portal has other logins")
- return EventHandlingResultFailed.WithError(err)
- } else if len(others) > 0 {
- log.Debug().Msg("Not deleting portal after chat delete as other logins are present")
- return EventHandlingResultSuccess
- }
- }
- err = portal.Delete(ctx)
- if err != nil {
- log.Err(err).Msg("Failed to delete portal from database")
- return EventHandlingResultFailed.WithMSSError(err)
- }
- err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false)
- if err != nil {
- log.Err(err).Msg("Failed to delete Matrix room")
- return EventHandlingResultFailed.WithMSSError(err)
- }
- // No MSS here as the portal was deleted
- return EventHandlingResultSuccess
-}
-
func (portal *Portal) handleMatrixMembership(
ctx context.Context,
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
- isStateRequest bool,
) EventHandlingResult {
- if evt.StateKey == nil {
- return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
- }
log := zerolog.Ctx(ctx)
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
if !ok {
@@ -2033,6 +1595,7 @@ func (portal *Portal) handleMatrixMembership(
return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent)
}
targetGhost, _ := target.(*Ghost)
+ targetUserLogin, _ := target.(*UserLogin)
membershipChange := &MatrixMembershipChange{
MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{
MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{
@@ -2043,60 +1606,19 @@ func (portal *Portal) handleMatrixMembership(
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- IsStateRequest: isStateRequest,
- PrevContent: prevContent,
+ PrevContent: prevContent,
},
- Target: target,
- Type: membershipChangeType,
+ Target: target,
+ TargetGhost: targetGhost,
+ TargetUserLogin: targetUserLogin,
+ Type: membershipChangeType,
}
- res, err := api.HandleMatrixMembership(ctx, membershipChange)
+ _, err = api.HandleMatrixMembership(ctx, membershipChange)
if err != nil {
log.Err(err).Msg("Failed to handle Matrix membership change")
return EventHandlingResultFailed.WithMSSError(err)
}
- didRedirectInvite := membershipChangeType == Invite &&
- targetGhost != nil &&
- res != nil &&
- res.RedirectTo != "" &&
- res.RedirectTo != targetGhost.ID
- if didRedirectInvite {
- log.Debug().
- Str("orig_id", string(targetGhost.ID)).
- Str("redirect_id", string(res.RedirectTo)).
- Msg("Invite was redirected to different ghost")
- var redirectGhost *Ghost
- redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo)
- if err != nil {
- log.Err(err).Msg("Failed to get redirect target ghost")
- return EventHandlingResultFailed.WithError(err)
- }
- if !isStateRequest {
- portal.sendRoomMeta(
- ctx,
- sender.User.DoublePuppet(ctx),
- time.UnixMilli(evt.Timestamp),
- event.StateMember,
- evt.GetStateKey(),
- &event.MemberEventContent{
- Membership: event.MembershipLeave,
- Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo),
- },
- true,
- nil,
- )
- }
- portal.sendRoomMeta(
- ctx,
- sender.User.DoublePuppet(ctx),
- time.UnixMilli(evt.Timestamp),
- event.StateMember,
- redirectGhost.Intent.GetMXID().String(),
- content,
- false,
- nil,
- )
- }
- return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite)
+ return EventHandlingResultSuccess.WithMSS()
}
func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange {
@@ -2121,27 +1643,13 @@ func (portal *Portal) handleMatrixPowerLevels(
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
- isStateRequest bool,
) EventHandlingResult {
- if evt.StateKey == nil || *evt.StateKey != "" {
- return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
- }
log := zerolog.Ctx(ctx)
content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent)
if !ok {
log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
}
- if content.CreateEvent == nil {
- ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
- if ok {
- var err error
- content.CreateEvent, err = ars.GetStateEvent(ctx, portal.MXID, event.StateCreate, "")
- if err != nil {
- return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err))
- }
- }
- }
api, ok := sender.Client.(PowerLevelHandlingNetworkAPI)
if !ok {
return EventHandlingResultIgnored.WithMSSError(ErrPowerLevelsNotSupported)
@@ -2150,7 +1658,6 @@ func (portal *Portal) handleMatrixPowerLevels(
if evt.Unsigned.PrevContent != nil {
_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.PowerLevelsEventContent)
- prevContent.CreateEvent = content.CreateEvent
}
plChange := &MatrixPowerLevelChange{
@@ -2163,8 +1670,7 @@ func (portal *Portal) handleMatrixPowerLevels(
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- IsStateRequest: isStateRequest,
- PrevContent: prevContent,
+ PrevContent: prevContent,
},
Users: make(map[id.UserID]*UserPowerLevelChange),
Events: make(map[string]*SinglePowerLevelChange),
@@ -2210,256 +1716,6 @@ func (portal *Portal) handleMatrixPowerLevels(
return EventHandlingResultSuccess.WithMSS()
}
-func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult {
- if evt.StateKey == nil || *evt.StateKey != "" || portal.MXID != evt.RoomID {
- return EventHandlingResultIgnored
- }
- log := *zerolog.Ctx(ctx)
- sentByBridge := evt.Sender == portal.Bridge.Bot.GetMXID() || portal.Bridge.IsGhostMXID(evt.Sender)
- var senderUser *User
- var err error
- if !sentByBridge {
- senderUser, err = portal.Bridge.GetUserByMXID(ctx, evt.Sender)
- if err != nil {
- log.Err(err).Msg("Failed to get tombstone sender user")
- return EventHandlingResultFailed.WithError(err)
- }
- }
- content, ok := evt.Content.Parsed.(*event.TombstoneEventContent)
- if !ok {
- log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
- return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
- }
- log = log.With().
- Stringer("replacement_room", content.ReplacementRoom).
- Logger()
- if content.ReplacementRoom == "" {
- log.Info().Msg("Received tombstone with no replacement room, cleaning up portal")
- err := portal.RemoveMXID(ctx)
- if err != nil {
- log.Err(err).Msg("Failed to remove portal MXID")
- return EventHandlingResultFailed.WithMSSError(err)
- }
- err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true)
- if err != nil {
- log.Err(err).Msg("Failed to clean up Matrix room")
- return EventHandlingResultFailed.WithError(err)
- }
- return EventHandlingResultSuccess
- }
- existingMemberEvt, err := portal.Bridge.Matrix.GetMemberInfo(ctx, content.ReplacementRoom, portal.Bridge.Bot.GetMXID())
- if err != nil {
- log.Err(err).Msg("Failed to get member info of bot in replacement room")
- return EventHandlingResultFailed.WithError(err)
- }
- leaveOnError := func() {
- if existingMemberEvt != nil && existingMemberEvt.Membership == event.MembershipJoin {
- return
- }
- log.Debug().Msg("Leaving replacement room with bot after tombstone validation failed")
- _, err = portal.Bridge.Bot.SendState(
- ctx,
- content.ReplacementRoom,
- event.StateMember,
- portal.Bridge.Bot.GetMXID().String(),
- &event.Content{
- Parsed: &event.MemberEventContent{
- Membership: event.MembershipLeave,
- Reason: fmt.Sprintf("Failed to validate tombstone sent by %s from %s", evt.Sender, evt.RoomID),
- },
- },
- time.Time{},
- )
- if err != nil {
- log.Err(err).Msg("Failed to leave replacement room after tombstone validation failed")
- }
- }
- var via []string
- if senderHS := evt.Sender.Homeserver(); senderHS != "" {
- via = []string{senderHS}
- }
- err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom, EnsureJoinedParams{Via: via})
- if err != nil {
- log.Err(err).Msg("Failed to join replacement room from tombstone")
- return EventHandlingResultFailed.WithError(err)
- }
- if !sentByBridge && !senderUser.Permissions.Admin {
- powers, err := portal.Bridge.Matrix.GetPowerLevels(ctx, content.ReplacementRoom)
- if err != nil {
- log.Err(err).Msg("Failed to get power levels in replacement room")
- leaveOnError()
- return EventHandlingResultFailed.WithError(err)
- }
- if powers.GetUserLevel(evt.Sender) < powers.Invite() {
- log.Warn().Msg("Tombstone sender doesn't have enough power to invite the bot to the replacement room")
- leaveOnError()
- return EventHandlingResultIgnored
- }
- }
- err = portal.UpdateMatrixRoomID(ctx, content.ReplacementRoom, UpdateMatrixRoomIDParams{
- DeleteOldRoom: true,
- FetchInfoVia: senderUser,
- })
- if errors.Is(err, ErrTargetRoomIsPortal) {
- return EventHandlingResultIgnored
- } else if err != nil {
- return EventHandlingResultFailed.WithError(err)
- }
- return EventHandlingResultSuccess
-}
-
-var ErrTargetRoomIsPortal = errors.New("target room is already a portal")
-var ErrRoomAlreadyExists = errors.New("this portal already has a room")
-
-type UpdateMatrixRoomIDParams struct {
- SyncDBMetadata func()
- FailIfMXIDSet bool
- OverwriteOldPortal bool
- TombstoneOldRoom bool
- DeleteOldRoom bool
-
- RoomCreateAlreadyLocked bool
-
- FetchInfoVia *User
- ChatInfo *ChatInfo
- ChatInfoSource *UserLogin
-}
-
-func (portal *Portal) UpdateMatrixRoomID(
- ctx context.Context,
- newRoomID id.RoomID,
- params UpdateMatrixRoomIDParams,
-) error {
- if !params.RoomCreateAlreadyLocked {
- portal.roomCreateLock.Lock()
- defer portal.roomCreateLock.Unlock()
- }
- oldRoom := portal.MXID
- if oldRoom == newRoomID {
- return nil
- } else if oldRoom != "" && params.FailIfMXIDSet {
- return ErrRoomAlreadyExists
- }
- log := zerolog.Ctx(ctx)
- portal.Bridge.cacheLock.Lock()
- // Wrap unlock in a sync.OnceFunc because we want to both defer it to catch early returns
- // and unlock it before return if nothing goes wrong.
- unlockCacheLock := sync.OnceFunc(portal.Bridge.cacheLock.Unlock)
- defer unlockCacheLock()
- if existingPortal, alreadyExists := portal.Bridge.portalsByMXID[newRoomID]; alreadyExists && !params.OverwriteOldPortal {
- log.Warn().Msg("Replacement room is already a portal, ignoring")
- return ErrTargetRoomIsPortal
- } else if alreadyExists {
- log.Debug().Msg("Replacement room is already a portal, overwriting")
- existingPortal.MXID = ""
- existingPortal.RoomCreated.Clear()
- err := existingPortal.Save(ctx)
- if err != nil {
- return fmt.Errorf("failed to clear mxid of existing portal: %w", err)
- }
- delete(portal.Bridge.portalsByMXID, portal.MXID)
- }
- portal.MXID = newRoomID
- portal.RoomCreated.Set()
- portal.Bridge.portalsByMXID[portal.MXID] = portal
- portal.NameSet = false
- portal.AvatarSet = false
- portal.TopicSet = false
- portal.InSpace = false
- portal.CapState = database.CapabilityState{}
- portal.lastCapUpdate = time.Time{}
- if params.SyncDBMetadata != nil {
- params.SyncDBMetadata()
- }
- unlockCacheLock()
- portal.updateLogger()
-
- err := portal.Save(ctx)
- if err != nil {
- log.Err(err).Msg("Failed to save portal in UpdateMatrixRoomID")
- return err
- }
- log.Info().Msg("Successfully followed tombstone and updated portal MXID")
- err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey)
- if err != nil {
- log.Err(err).Msg("Failed to update in_space flag for user portals after updating portal MXID")
- }
- go portal.addToUserSpaces(ctx)
- if params.FetchInfoVia != nil {
- go portal.updateInfoAfterTombstone(ctx, params.FetchInfoVia)
- } else if params.ChatInfo != nil {
- go portal.UpdateInfo(ctx, params.ChatInfo, params.ChatInfoSource, nil, time.Time{})
- } else if params.ChatInfoSource != nil {
- portal.UpdateCapabilities(ctx, params.ChatInfoSource, true)
- portal.UpdateBridgeInfo(ctx)
- }
- go func() {
- // TODO this might become unnecessary if UpdateInfo starts taking care of it
- _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
- Parsed: &event.ElementFunctionalMembersContent{
- ServiceMembers: []id.UserID{portal.Bridge.Bot.GetMXID()},
- },
- }, time.Time{})
- if err != nil {
- if err != nil {
- log.Warn().Err(err).Msg("Failed to set service members in new room")
- }
- }
- }()
- if params.TombstoneOldRoom && oldRoom != "" {
- _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateTombstone, "", &event.Content{
- Parsed: &event.TombstoneEventContent{
- Body: "Room has been replaced.",
- ReplacementRoom: newRoomID,
- },
- }, time.Now())
- if err != nil {
- log.Err(err).Msg("Failed to send tombstone event to old room")
- }
- }
- if params.DeleteOldRoom && oldRoom != "" {
- go func() {
- err = portal.Bridge.Bot.DeleteRoom(ctx, oldRoom, true)
- if err != nil {
- log.Err(err).Msg("Failed to clean up old Matrix room after updating portal MXID")
- }
- }()
- }
- return nil
-}
-
-func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) {
- log := zerolog.Ctx(ctx)
- logins, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey)
- if err != nil {
- log.Err(err).Msg("Failed to get user logins in portal to sync info")
- return
- }
- var preferredLogin *UserLogin
- for _, login := range logins {
- if !login.Client.IsLoggedIn() {
- continue
- } else if preferredLogin == nil {
- preferredLogin = login
- } else if senderUser != nil && login.User == senderUser {
- preferredLogin = login
- }
- }
- if preferredLogin == nil {
- log.Warn().Msg("No logins found to sync info")
- return
- }
- info, err := preferredLogin.Client.GetChatInfo(ctx, portal)
- if err != nil {
- log.Err(err).Msg("Failed to get chat info")
- return
- }
- log.Info().
- Str("info_source_login", string(preferredLogin.ID)).
- Msg("Fetched info to update portal after tombstone")
- portal.UpdateInfo(ctx, info, preferredLogin, nil, time.Time{})
-}
-
func (portal *Portal) handleMatrixRedaction(
ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event,
) EventHandlingResult {
@@ -2559,7 +1815,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin,
err = portal.createMatrixRoomInLoop(ctx, source, info, bundle)
if err != nil {
log.Err(err).Msg("Failed to create portal to handle event")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
if evtType == RemoteEventChatResync {
log.Debug().Msg("Not handling chat resync event further as portal was created by it")
@@ -2578,7 +1834,6 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin,
switch evtType {
case RemoteEventUnknown:
log.Debug().Msg("Ignoring remote event with type unknown")
- res = EventHandlingResultIgnored
case RemoteEventMessage, RemoteEventMessageUpsert:
res = portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage))
case RemoteEventEdit:
@@ -2617,46 +1872,6 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin,
return
}
-func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) {
- if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" {
- return
- }
- ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
- if !ok {
- return
- }
- portal.functionalMembersLock.Lock()
- defer portal.functionalMembersLock.Unlock()
- var functionalMembers *event.ElementFunctionalMembersContent
- if portal.functionalMembersCache != nil {
- functionalMembers = portal.functionalMembersCache
- } else {
- evt, err := ars.GetStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "")
- if err != nil && !errors.Is(err, mautrix.MNotFound) {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get functional members state event")
- return
- }
- functionalMembers = &event.ElementFunctionalMembersContent{}
- if evt != nil {
- evtContent, ok := evt.Content.Parsed.(*event.ElementFunctionalMembersContent)
- if ok && evtContent != nil {
- functionalMembers = evtContent
- }
- }
- }
- // TODO what about non-double-puppeted user ghosts?
- functionalMembers.Add(portal.Bridge.Bot.GetMXID())
- if functionalMembers.Add(ghost.Intent.GetMXID()) {
- _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
- Parsed: functionalMembers,
- }, time.Time{})
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to update functional members state event")
- return
- }
- }
-}
-
func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) {
var ghost *Ghost
if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID {
@@ -2680,7 +1895,6 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS
return
}
ghost.UpdateInfoIfNecessary(ctx, source, evtType)
- portal.ensureFunctionalMember(ctx, ghost)
}
if sender.IsFromMe {
intent = source.User.DoublePuppet(ctx)
@@ -2792,7 +2006,7 @@ func (portal *Portal) getRelationMeta(
log.Err(err).Msg("Failed to get last thread message from database")
}
if prevThreadEvent == nil {
- prevThreadEvent = ptr.Clone(threadRoot)
+ prevThreadEvent = threadRoot
}
}
return
@@ -2851,7 +2065,6 @@ func (portal *Portal) sendConvertedMessage(
allSuccess := true
for i, part := range converted.Parts {
portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent)
- part.Content.BeeperDisappearingTimer = converted.Disappear.ToEventContent()
dbMessage := &database.Message{
ID: id,
PartID: part.ID,
@@ -2896,14 +2109,13 @@ func (portal *Portal) sendConvertedMessage(
logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database")
allSuccess = false
}
- if converted.Disappear.Type != event.DisappearingTypeNone && !dbMessage.HasFakeMXID() {
- if converted.Disappear.Type == event.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() {
+ if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() {
+ if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() {
converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer)
}
portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{
RoomID: portal.MXID,
EventID: dbMessage.MXID,
- Timestamp: dbMessage.Timestamp,
DisappearingSetting: converted.Disappear,
})
}
@@ -2992,7 +2204,7 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin,
err = portal.Bridge.DB.Message.Update(ctx, part)
if err != nil {
log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database")
- handleRes = EventHandlingResultFailed.WithError(err)
+ handleRes = EventHandlingResultFailed
}
}
}
@@ -3056,7 +2268,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin
} else {
log.Err(err).Msg("Failed to convert remote message")
portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
}
_, res = portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil)
@@ -3101,7 +2313,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e
existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID)
if err != nil {
log.Err(err).Msg("Failed to get edit target message")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
}
if existing == nil {
@@ -3126,7 +2338,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e
} else if err != nil {
log.Err(err).Msg("Failed to convert remote edit")
portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
res := portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt))
if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) {
@@ -3275,7 +2487,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User
targetMessage, err := portal.getTargetMessagePart(ctx, evt)
if err != nil {
log.Err(err).Msg("Failed to get target message for reaction")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if targetMessage == nil {
// TODO use deterministic event ID as target if applicable?
log.Warn().Msg("Target message for reaction not found")
@@ -3289,7 +2501,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User
}
if err != nil {
log.Err(err).Msg("Failed to get existing reactions for reaction sync")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction)
for _, existingReaction := range existingReactions {
@@ -3411,7 +2623,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi
targetMessage, err := portal.getTargetMessagePart(ctx, evt)
if err != nil {
log.Err(err).Msg("Failed to get target message for reaction")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if targetMessage == nil {
// TODO use deterministic event ID as target if applicable?
log.Warn().Msg("Target message for reaction not found")
@@ -3421,7 +2633,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi
existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) {
log.Debug().Msg("Ignoring duplicate reaction")
return EventHandlingResultIgnored
@@ -3491,7 +2703,7 @@ func (portal *Portal) sendConvertedReaction(
})
if err != nil {
logContext(log.Err(err)).Msg("Failed to send reaction to Matrix")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
logContext(log.Debug()).
Stringer("event_id", resp.EventID).
@@ -3500,7 +2712,7 @@ func (portal *Portal) sendConvertedReaction(
err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction)
if err != nil {
logContext(log.Err(err)).Msg("Failed to save reaction to database")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
return EventHandlingResultSuccess
}
@@ -3526,7 +2738,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us
targetReaction, err := portal.getTargetReaction(ctx, evt)
if err != nil {
log.Err(err).Msg("Failed to get target reaction for removal")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if targetReaction == nil {
log.Warn().Msg("Target reaction not found")
return EventHandlingResultIgnored
@@ -3550,7 +2762,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us
}, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction})
if err != nil {
log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction)
if err != nil {
@@ -3564,7 +2776,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use
targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage())
if err != nil {
log.Err(err).Msg("Failed to get target message for removal")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if len(targetParts) == 0 {
log.Debug().Msg("Target message not found")
return EventHandlingResultIgnored
@@ -3572,14 +2784,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use
onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe)
onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe()
if onlyForMe && portal.Receiver == "" {
- _, others, err := portal.findOtherLogins(ctx, source)
- if err != nil {
- log.Err(err).Msg("Failed to check if portal has other logins")
- return EventHandlingResultFailed.WithError(err)
- } else if len(others) > 0 {
- log.Debug().Msg("Ignoring delete for me event in portal with multiple logins")
- return EventHandlingResultIgnored
- }
+ // TODO check if there are other user logins before deleting
}
intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove)
@@ -3641,7 +2846,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL
if err != nil {
log.Err(err).Str("last_target_id", string(lastTargetID)).
Msg("Failed to get last target message for read receipt")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if lastTarget == nil {
log.Debug().Str("last_target_id", string(lastTargetID)).
Msg("Last target message not found")
@@ -3660,7 +2865,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL
if err != nil {
log.Err(err).Str("target_id", string(targetID)).
Msg("Failed to get target message for read receipt")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) {
lastTarget = target
}
@@ -3690,24 +2895,20 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL
return evt.Int64("target_stream_order", targetStreamOrder)
}
err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt))
- if readUpTo.IsZero() {
- readUpTo = getEventTS(evt)
- }
} else {
addTargetLog = func(evt *zerolog.Event) *zerolog.Event {
return evt.Stringer("target_mxid", lastTarget.MXID)
}
err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt))
- readUpTo = lastTarget.Timestamp
}
if err != nil {
addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else {
addTargetLog(log.Debug()).Msg("Bridged read receipt")
}
if sender.IsFromMe {
- portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo)
+ portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID)
}
return EventHandlingResultSuccess
}
@@ -3724,13 +2925,13 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo
err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread())
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
return EventHandlingResultSuccess
}
func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult {
- if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") {
+ if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID {
return EventHandlingResultIgnored
}
intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt)
@@ -3742,7 +2943,7 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U
targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target)
if err != nil {
log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else if len(targetParts) == 0 {
continue
} else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost {
@@ -3777,7 +2978,7 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin,
err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
if timeout == 0 {
portal.currentlyTypingGhosts.Remove(intent.GetMXID())
@@ -3791,7 +2992,7 @@ func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *Us
info, err := evt.GetChatInfoChange(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt))
return EventHandlingResultSuccess
@@ -3829,43 +3030,22 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo
return EventHandlingResultSuccess
}
-func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) {
- others, err = portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
- if err != nil {
- return
- }
- others = slices.DeleteFunc(others, func(up *database.UserPortal) bool {
- if up.LoginID == source.ID {
- ownUP = up
- return true
- }
- return false
- })
- return
-}
-
-type childDeleteProxy struct {
- RemoteChatDeleteWithChildren
- child networkid.PortalKey
- done func()
-}
-
-func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context {
- return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children")
-}
-func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child }
-func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false }
-func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {}
-func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() }
-
func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult {
log := zerolog.Ctx(ctx)
if portal.Receiver == "" && evt.DeleteOnlyForMe() {
- ownUP, logins, err := portal.findOtherLogins(ctx, source)
+ logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
if err != nil {
log.Err(err).Msg("Failed to check if portal has other logins")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
+ var ownUP *database.UserPortal
+ logins = slices.DeleteFunc(logins, func(up *database.UserPortal) bool {
+ if up.LoginID == source.ID {
+ ownUP = up
+ return true
+ }
+ return false
+ })
if len(logins) > 0 {
log.Debug().Msg("Not deleting portal with other logins in remote chat delete event")
if ownUP != nil {
@@ -3886,47 +3066,22 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo
)
if err != nil {
log.Err(err).Msg("Failed to send leave state event for user after remote chat delete")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else {
log.Debug().Msg("Sent leave state event for user after remote chat delete")
return EventHandlingResultSuccess
}
}
}
- if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace {
- children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey)
- if err != nil {
- log.Err(err).Msg("Failed to fetch children to delete")
- return EventHandlingResultFailed.WithError(err)
- }
- log.Debug().
- Int("portal_count", len(children)).
- Msg("Deleting child portals before remote chat delete")
- var wg sync.WaitGroup
- wg.Add(len(children))
- for _, child := range children {
- child.queueEvent(ctx, &portalRemoteEvent{
- evt: &childDeleteProxy{
- RemoteChatDeleteWithChildren: childDeleter,
- child: child.PortalKey,
- done: wg.Done,
- },
- source: source,
- evtType: RemoteEventChatDelete,
- })
- }
- wg.Wait()
- log.Debug().Msg("Finished deleting child portals")
- }
err := portal.Delete(ctx)
if err != nil {
log.Err(err).Msg("Failed to delete portal from database")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
}
err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false)
if err != nil {
log.Err(err).Msg("Failed to delete Matrix room")
- return EventHandlingResultFailed.WithError(err)
+ return EventHandlingResultFailed
} else {
log.Info().Msg("Deleted room after remote chat delete event")
return EventHandlingResultSuccess
@@ -3973,43 +3128,12 @@ type PortalInfo = ChatInfo
type ChatMember struct {
EventSender
Membership event.Membership
- // Per-room nickname for the user. Not yet used.
- Nickname *string
- // The power level to set for the user when syncing power levels.
+ Nickname *string
PowerLevel *int
- // Optional user info to sync the ghost user while updating membership.
- UserInfo *UserInfo
- // The user who sent the membership change (user who invited/kicked/banned this user).
- // Not yet used. Not applicable if Membership is join or knock.
- MemberSender EventSender
- // Extra fields to include in the member event.
+ UserInfo *UserInfo
+
MemberEventExtra map[string]any
- // The expected previous membership. If this doesn't match, the change is ignored.
- PrevMembership event.Membership
-}
-
-type ChatMemberMap map[networkid.UserID]ChatMember
-
-// Set adds the given entry to this map, overwriting any existing entry with the same Sender field.
-func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap {
- if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe {
- return cmm
- }
- cmm[member.Sender] = member
- return cmm
-}
-
-// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists.
-// It returns true if the entry was added, false otherwise.
-func (cmm ChatMemberMap) Add(member ChatMember) bool {
- if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe {
- return false
- }
- if _, exists := cmm[member.Sender]; exists {
- return false
- }
- cmm[member.Sender] = member
- return true
+ PrevMembership event.Membership
}
type ChatMemberList struct {
@@ -4019,10 +3143,6 @@ type ChatMemberList struct {
// Should the bridge call IsThisUser for every member in the list?
// This should be used when SenderLogin can't be filled accurately.
CheckAllLogins bool
- // Should any changes have the `com.beeper.exclude_from_timeline` flag set by default?
- // This is recommended for syncs with non-real-time changes.
- // Real-time changes (e.g. a user joining) should not set this flag set.
- ExcludeChangesFromTimeline bool
// The total number of members in the chat, regardless of how many of those members are included in MemberMap.
TotalMemberCount int
@@ -4033,7 +3153,7 @@ type ChatMemberList struct {
// Deprecated: Use MemberMap instead to avoid duplicate entries
Members []ChatMember
- MemberMap ChatMemberMap
+ MemberMap map[networkid.UserID]ChatMember
PowerLevels *PowerLevelOverrides
}
@@ -4135,11 +3255,9 @@ type ChatInfo struct {
Disappear *database.DisappearingSetting
ParentID *networkid.PortalID
- UserLocal *UserLocalPortalInfo
- MessageRequest *bool
- CanBackfill bool
+ UserLocal *UserLocalPortalInfo
- ExcludeChangesFromTimeline bool
+ CanBackfill bool
ExtraUpdates ExtraUpdater[*Portal]
}
@@ -4173,36 +3291,26 @@ type UserLocalPortalInfo struct {
Tag *event.RoomTag
}
-func (portal *Portal) updateName(
- ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool,
-) bool {
+func (portal *Portal) updateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool {
if portal.Name == name && (portal.NameSet || portal.MXID == "") {
return false
}
portal.Name = name
- portal.NameSet = portal.sendRoomMeta(
- ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil,
- )
+ portal.NameSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name})
return true
}
-func (portal *Portal) updateTopic(
- ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool,
-) bool {
+func (portal *Portal) updateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool {
if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") {
return false
}
portal.Topic = topic
- portal.TopicSet = portal.sendRoomMeta(
- ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil,
- )
+ portal.TopicSet = portal.sendRoomMeta(ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic})
return true
}
-func (portal *Portal) updateAvatar(
- ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool,
-) bool {
- if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") {
+func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool {
+ if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") {
return false
}
portal.AvatarID = avatar.ID
@@ -4218,15 +3326,13 @@ func (portal *Portal) updateAvatar(
portal.AvatarSet = false
zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar")
return true
- } else if newHash == portal.AvatarHash && portal.AvatarMXC != "" && portal.AvatarSet {
+ } else if newHash == portal.AvatarHash && portal.AvatarSet {
return true
}
portal.AvatarMXC = newMXC
portal.AvatarHash = newHash
}
- portal.AvatarSet = portal.sendRoomMeta(
- ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil,
- )
+ portal.AvatarSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC})
return true
}
@@ -4257,11 +3363,10 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) {
Creator: portal.Bridge.Bot.GetMXID(),
Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(),
Channel: event.BridgeInfoSection{
- ID: string(portal.ID),
- DisplayName: portal.Name,
- AvatarURL: portal.AvatarMXC,
- Receiver: string(portal.Receiver),
- MessageRequest: portal.MessageRequest,
+ ID: string(portal.ID),
+ DisplayName: portal.Name,
+ AvatarURL: portal.AvatarMXC,
+ Receiver: string(portal.Receiver),
// TODO external URL?
},
BeeperRoomTypeV2: string(portal.RoomType),
@@ -4269,10 +3374,6 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) {
if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM {
bridgeInfo.BeeperRoomType = "dm"
}
- if bridgeInfo.Protocol.ID == "slackgo" {
- bridgeInfo.TempSlackRemoteIDMigratedFlag = true
- bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true
- }
parent := portal.GetTopLevelParent()
if parent != nil {
bridgeInfo.Network = &event.BridgeInfoSection{
@@ -4294,8 +3395,8 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) {
return
}
stateKey, bridgeInfo := portal.getBridgeInfo()
- portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil)
- portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil)
+ portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo)
+ portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo)
}
func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool {
@@ -4317,22 +3418,13 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin,
Str("old_id", portal.CapState.ID).
Str("new_id", capID).
Msg("Sending new room capability event")
- success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil)
+ success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps)
if !success {
return false
}
portal.CapState = database.CapabilityState{
Source: source.ID,
ID: capID,
- Flags: portal.CapState.Flags,
- }
- if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) {
- zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event")
- success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil)
- if !success {
- return false
- }
- portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet
}
portal.lastCapUpdate = time.Now()
if implicit {
@@ -4359,27 +3451,15 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri
return
}
-func (portal *Portal) sendRoomMeta(
- ctx context.Context,
- sender MatrixAPI,
- ts time.Time,
- eventType event.Type,
- stateKey string,
- content any,
- excludeFromTimeline bool,
- extra map[string]any,
-) bool {
+func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool {
if portal.MXID == "" {
return false
}
- if extra == nil {
- extra = make(map[string]any)
- }
- if excludeFromTimeline {
- extra["com.beeper.exclude_from_timeline"] = true
- }
+ var extra map[string]any
if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) {
- extra["fi.mau.implicit_name"] = true
+ extra = map[string]any{
+ "fi.mau.implicit_name": true,
+ }
}
_, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{
Parsed: content,
@@ -4391,55 +3471,9 @@ func (portal *Portal) sendRoomMeta(
Msg("Failed to set room metadata")
return false
}
- if eventType == event.StateBeeperDisappearingTimer {
- // TODO remove this debug log at some point
- zerolog.Ctx(ctx).Debug().
- Any("content", content).
- Msg("Sent new disappearing timer event")
- }
return true
}
-func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) {
- if !portal.Bridge.Config.RevertFailedStateChanges {
- return
- }
- if evt.GetStateKey() != "" && evt.Type != event.StateMember {
- return
- }
- switch evt.Type {
- case event.StateRoomName:
- portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil)
- case event.StateRoomAvatar:
- portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil)
- case event.StateTopic:
- portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil)
- case event.StateBeeperDisappearingTimer:
- portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil)
- case event.StateMember:
- var prevContent *event.MemberEventContent
- var extra map[string]any
- if evt.Unsigned.PrevContent != nil {
- _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
- prevContent = evt.Unsigned.PrevContent.AsMember()
- newContent := evt.Content.AsMember()
- if prevContent.Membership == newContent.Membership {
- return
- }
- extra = evt.Unsigned.PrevContent.Raw
- } else {
- prevContent = &event.MemberEventContent{Membership: event.MembershipLeave}
- }
- if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange {
- if extra == nil {
- extra = make(map[string]any)
- }
- extra["com.beeper.member_rollback"] = true
- portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra)
- }
- }
-}
-
func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) {
if members == nil {
invite = []id.UserID{source.UserMXID}
@@ -4523,39 +3557,6 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi
return false
}
-func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool {
- switch rule.JoinRule {
- case event.JoinRulePublic:
- return true
- case event.JoinRuleKnockRestricted, event.JoinRuleRestricted:
- for _, allow := range rule.Allow {
- if allow.Type == "fi.mau.spam_checker" {
- return true
- }
- }
- }
- return false
-}
-
-func (portal *Portal) roomIsPublic(ctx context.Context) bool {
- mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
- if !ok {
- return false
- }
- evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "")
- if err != nil {
- zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public")
- return false
- } else if evt == nil {
- return false
- }
- content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent)
- if !ok {
- return false
- }
- return looksDirectlyJoinable(content)
-}
-
func (portal *Portal) syncParticipants(
ctx context.Context,
members *ChatMemberList,
@@ -4586,12 +3587,6 @@ func (portal *Portal) syncParticipants(
}
delete(currentMembers, portal.Bridge.Bot.GetMXID())
powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower)
- addExcludeFromTimeline := func(raw map[string]any) {
- _, hasKey := raw["com.beeper.exclude_from_timeline"]
- if !hasKey && members.ExcludeChangesFromTimeline {
- raw["com.beeper.exclude_from_timeline"] = true
- }
- }
syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool {
if member.Membership == "" {
member.Membership = event.MembershipJoin
@@ -4621,10 +3616,12 @@ func (portal *Portal) syncParticipants(
Displayname: currentMember.Displayname,
AvatarURL: currentMember.AvatarURL,
}
- wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)}
- addExcludeFromTimeline(wrappedContent.Raw)
+ wrappedContent := &event.Content{Parsed: content, Raw: maps.Clone(member.MemberEventExtra)}
+ if wrappedContent.Raw == nil {
+ wrappedContent.Raw = make(map[string]any)
+ }
thisEvtSender := sender
- if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) {
+ if member.Membership == event.MembershipJoin {
content.Membership = event.MembershipInvite
if intent != nil {
wrappedContent.Raw["fi.mau.will_auto_accept"] = true
@@ -4654,11 +3651,7 @@ func (portal *Portal) syncParticipants(
currentMember.Membership = event.MembershipLeave
}
}
- if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID {
- _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts)
- } else {
- _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts)
- }
+ _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts)
if err != nil {
addLogContext(log.Err(err)).
Str("new_membership", string(content.Membership)).
@@ -4671,8 +3664,7 @@ func (portal *Portal) syncParticipants(
if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin {
content.Membership = event.MembershipJoin
- wrappedJoinContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)}
- addExcludeFromTimeline(wrappedContent.Raw)
+ wrappedJoinContent := &event.Content{Parsed: content, Raw: member.MemberEventExtra}
_, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts)
if err != nil {
addLogContext(log.Err(err)).
@@ -4735,7 +3727,7 @@ func (portal *Portal) syncParticipants(
if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan {
continue
}
- if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) {
+ if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil {
continue
}
_, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{
@@ -4745,9 +3737,6 @@ func (portal *Portal) syncParticipants(
Displayname: memberEvt.Displayname,
Reason: "User is not in remote chat",
},
- Raw: map[string]any{
- "com.beeper.exclude_from_timeline": members.ExcludeChangesFromTimeline,
- },
}, time.Now())
if err != nil {
zerolog.Ctx(ctx).Err(err).
@@ -4816,28 +3805,16 @@ func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.M
return content
}
-type UpdateDisappearingSettingOpts struct {
- Sender MatrixAPI
- Timestamp time.Time
- Implicit bool
- Save bool
- SendNotice bool
-
- ExcludeFromTimeline bool
-}
-
-func (portal *Portal) UpdateDisappearingSetting(
- ctx context.Context,
- setting database.DisappearingSetting,
- opts UpdateDisappearingSettingOpts,
-) bool {
- setting = setting.Normalize()
+func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool {
+ if setting.Timer == 0 {
+ setting.Type = ""
+ }
if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type {
return false
}
portal.Disappear.Type = setting.Type
portal.Disappear.Timer = setting.Timer
- if opts.Save {
+ if save {
err := portal.Save(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting")
@@ -4846,45 +3823,19 @@ func (portal *Portal) UpdateDisappearingSetting(
if portal.MXID == "" {
return true
}
-
- if opts.Sender == nil {
- opts.Sender = portal.Bridge.Bot
+ content := DisappearingMessageNotice(setting.Timer, implicit)
+ if sender == nil {
+ sender = portal.Bridge.Bot
}
- if opts.Timestamp.IsZero() {
- opts.Timestamp = time.Now()
- }
- portal.sendRoomMeta(
- ctx,
- opts.Sender,
- opts.Timestamp,
- event.StateBeeperDisappearingTimer,
- "",
- setting.ToEventContent(),
- opts.ExcludeFromTimeline,
- nil,
- )
-
- if !opts.SendNotice {
- return true
- }
- content := DisappearingMessageNotice(setting.Timer, opts.Implicit)
- _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{
+ _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{
Parsed: content,
- Raw: map[string]any{
- "com.beeper.action_message": map[string]any{
- "type": "disappearing_timer",
- "timer": setting.Timer.Milliseconds(),
- "timer_type": setting.Type,
- "implicit": opts.Implicit,
- },
- },
- }, &MatrixSendExtra{Timestamp: opts.Timestamp})
+ }, &MatrixSendExtra{Timestamp: ts})
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice")
} else {
zerolog.Ctx(ctx).Debug().
Dur("new_timer", portal.Disappear.Timer).
- Bool("implicit", opts.Implicit).
+ Bool("implicit", implicit).
Msg("Sent disappearing messages notice")
}
return true
@@ -4946,13 +3897,13 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch
return
}
}
- changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}, false) || changed
+ changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}) || changed
changed = portal.updateAvatar(ctx, &Avatar{
ID: ghost.AvatarID,
MXC: ghost.AvatarMXC,
Hash: ghost.AvatarHash,
Remove: ghost.AvatarID == "",
- }, nil, time.Time{}, false) || changed
+ }, nil, time.Time{}) || changed
return
}
@@ -4961,36 +3912,28 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us
if info.Name == DefaultChatName {
if portal.NameIsCustom {
portal.NameIsCustom = false
- changed = portal.updateName(ctx, "", sender, ts, info.ExcludeChangesFromTimeline) || changed
+ changed = portal.updateName(ctx, "", sender, ts) || changed
}
} else if info.Name != nil {
portal.NameIsCustom = true
- changed = portal.updateName(ctx, *info.Name, sender, ts, info.ExcludeChangesFromTimeline) || changed
+ changed = portal.updateName(ctx, *info.Name, sender, ts) || changed
}
if info.Topic != nil {
- changed = portal.updateTopic(ctx, *info.Topic, sender, ts, info.ExcludeChangesFromTimeline) || changed
+ changed = portal.updateTopic(ctx, *info.Topic, sender, ts) || changed
}
if info.Avatar != nil {
portal.NameIsCustom = true
- changed = portal.updateAvatar(ctx, info.Avatar, sender, ts, info.ExcludeChangesFromTimeline) || changed
+ changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed
}
if info.Disappear != nil {
- changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{
- Sender: sender,
- Timestamp: ts,
- Implicit: false,
- Save: false,
-
- SendNotice: !info.ExcludeChangesFromTimeline,
- ExcludeFromTimeline: info.ExcludeChangesFromTimeline,
- }) || changed
+ changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed
}
if info.ParentID != nil {
changed = portal.updateParent(ctx, *info.ParentID, source) || changed
}
if info.JoinRule != nil {
// TODO change detection instead of spamming this every time?
- portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil)
+ portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule)
}
if info.Type != nil && portal.RoomType != *info.Type {
if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) {
@@ -5003,10 +3946,6 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us
portal.RoomType = *info.Type
}
}
- if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest {
- changed = true
- portal.MessageRequest = *info.MessageRequest
- }
if info.Members != nil && portal.MXID != "" && source != nil {
err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{})
if err != nil {
@@ -5048,9 +3987,6 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
}
return nil
}
- if portal.deleted.IsSet() {
- return ErrPortalIsDeleted
- }
waiter := make(chan struct{})
closed := false
evt := &portalCreateEvent{
@@ -5068,11 +4004,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
if PortalEventBuffer == 0 {
go portal.queueEvent(ctx, evt)
} else {
- select {
- case portal.events <- evt:
- case <-portal.deleted.GetChan():
- return ErrPortalIsDeleted
- }
+ portal.events <- evt
}
select {
case <-ctx.Done():
@@ -5083,11 +4015,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
}
func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error {
- cancellableCtx, cancel := context.WithCancel(ctx)
- defer cancel()
- portal.cancelRoomCreate.CompareAndSwap(nil, &cancel)
portal.roomCreateLock.Lock()
- portal.cancelRoomCreate.Store(&cancel)
defer portal.roomCreateLock.Unlock()
if portal.MXID != "" {
if source != nil {
@@ -5098,7 +4026,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
log := zerolog.Ctx(ctx).With().
Str("action", "create matrix room").
Logger()
- cancellableCtx = log.WithContext(cancellableCtx)
ctx = log.WithContext(ctx)
log.Info().Msg("Creating Matrix room")
@@ -5107,16 +4034,16 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
if info != nil {
log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info")
}
- info, err = source.Client.GetChatInfo(cancellableCtx, portal)
+ info, err = source.Client.GetChatInfo(ctx, portal)
if err != nil {
log.Err(err).Msg("Failed to update portal info for creation")
return err
}
}
- portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{})
- if cancellableCtx.Err() != nil {
- return cancellableCtx.Err()
+ portal.UpdateInfo(ctx, info, source, nil, time.Time{})
+ if ctx.Err() != nil {
+ return ctx.Err()
}
powerLevels := &event.PowerLevelsEventContent{
@@ -5129,7 +4056,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
portal.Bridge.Bot.GetMXID(): 9001,
},
}
- initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels)
+ initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels)
if err != nil {
log.Err(err).Msg("Failed to process participant list for portal creation")
return err
@@ -5138,12 +4065,15 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
req := mautrix.ReqCreateRoom{
Visibility: "private",
+ Name: portal.Name,
+ Topic: portal.Topic,
CreationContent: make(map[string]any),
InitialState: make([]*event.Event, 0, 6),
Preset: "private_chat",
IsDirect: portal.RoomType == database.RoomTypeDM,
PowerLevelOverride: powerLevels,
BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey),
+ RoomVersion: event.RoomV11,
}
autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites
if autoJoinInvites {
@@ -5156,7 +4086,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
req.CreationContent["type"] = event.RoomTypeSpace
}
bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo()
- roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal)
+ roomFeatures := source.Client.GetCapabilities(ctx, portal)
portal.CapState = database.CapabilityState{
Source: source.ID,
ID: roomFeatures.GetID(),
@@ -5179,47 +4109,19 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
StateKey: &bridgeInfoStateKey,
Type: event.StateBeeperRoomFeatures,
Content: event.Content{Parsed: roomFeatures},
- }, &event.Event{
- Type: event.StateTopic,
- Content: event.Content{
- Parsed: &event.TopicEventContent{Topic: portal.Topic},
- Raw: map[string]any{
- "com.beeper.exclude_from_timeline": true,
- },
- },
})
- if roomFeatures.DisappearingTimer != nil {
+ if req.Topic == "" {
+ // Add explicit topic event if topic is empty to ensure the event is set.
+ // This ensures that there won't be an extra event later if PUT /state/... is called.
req.InitialState = append(req.InitialState, &event.Event{
- Type: event.StateBeeperDisappearingTimer,
- Content: event.Content{
- Parsed: portal.Disappear.ToEventContent(),
- Raw: map[string]any{
- "com.beeper.exclude_from_timeline": true,
- },
- },
- })
- portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet
- }
- if portal.Name != "" {
- req.InitialState = append(req.InitialState, &event.Event{
- Type: event.StateRoomName,
- Content: event.Content{
- Parsed: &event.RoomNameEventContent{Name: portal.Name},
- Raw: map[string]any{
- "com.beeper.exclude_from_timeline": true,
- },
- },
+ Type: event.StateTopic,
+ Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}},
})
}
if portal.AvatarMXC != "" {
req.InitialState = append(req.InitialState, &event.Event{
- Type: event.StateRoomAvatar,
- Content: event.Content{
- Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC},
- Raw: map[string]any{
- "com.beeper.exclude_from_timeline": true,
- },
- },
+ Type: event.StateRoomAvatar,
+ Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}},
})
}
if portal.Parent != nil && portal.Parent.MXID != "" {
@@ -5238,9 +4140,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
Content: event.Content{Parsed: info.JoinRule},
})
}
- if cancellableCtx.Err() != nil {
- return cancellableCtx.Err()
- }
roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req)
if err != nil {
log.Err(err).Msg("Failed to create Matrix room")
@@ -5251,7 +4150,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
portal.TopicSet = true
portal.NameSet = true
portal.MXID = roomID
- portal.RoomCreated.Set()
portal.Bridge.cacheLock.Lock()
portal.Bridge.portalsByMXID[roomID] = portal
portal.Bridge.cacheLock.Unlock()
@@ -5298,55 +4196,42 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
}
}
}
- portal.addToUserSpaces(ctx)
- if info.CanBackfill &&
- portal.Bridge.Config.Backfill.Enabled &&
- portal.RoomType != database.RoomTypeSpace &&
- !portal.Bridge.Background {
+ if portal.Parent == nil {
+ if portal.Receiver != "" {
+ login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver)
+ if login != nil {
+ up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey)
+ if err != nil {
+ log.Err(err).Msg("Failed to get user portal to add portal to spaces")
+ } else {
+ login.inPortalCache.Remove(portal.PortalKey)
+ go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
+ }
+ }
+ } else {
+ userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
+ if err != nil {
+ log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces")
+ } else {
+ for _, up := range userPortals {
+ login := portal.Bridge.GetCachedUserLoginByID(up.LoginID)
+ if login != nil {
+ login.inPortalCache.Remove(portal.PortalKey)
+ go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
+ }
+ }
+ }
+ }
+ }
+ if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background {
portal.doForwardBackfill(ctx, source, nil, backfillBundle)
}
return nil
}
-func (portal *Portal) addToUserSpaces(ctx context.Context) {
- if portal.Parent != nil {
- return
- }
- log := zerolog.Ctx(ctx)
- withoutCancelCtx := log.WithContext(portal.Bridge.BackgroundCtx)
- if portal.Receiver != "" {
- login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver)
- if login != nil {
- up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey)
- if err != nil {
- log.Err(err).Msg("Failed to get user portal to add portal to spaces")
- } else {
- login.inPortalCache.Remove(portal.PortalKey)
- go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
- }
- }
- } else {
- userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey)
- if err != nil {
- log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces")
- } else {
- for _, up := range userPortals {
- login := portal.Bridge.GetCachedUserLoginByID(up.LoginID)
- if login != nil {
- login.inPortalCache.Remove(portal.PortalKey)
- go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues())
- }
- }
- }
- }
-}
-
func (portal *Portal) Delete(ctx context.Context) error {
- if portal.deleted.IsSet() {
- return nil
- }
portal.removeInPortalCache(ctx)
- err := portal.safeDBDelete(ctx)
+ err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
if err != nil {
return err
}
@@ -5356,21 +4241,11 @@ func (portal *Portal) Delete(ctx context.Context) error {
return nil
}
-func (portal *Portal) safeDBDelete(ctx context.Context) error {
- err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey)
- if err != nil {
- return fmt.Errorf("failed to delete messages in portal: %w", err)
- }
- // TODO delete child portals?
- return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
-}
-
func (portal *Portal) RemoveMXID(ctx context.Context) error {
if portal.MXID == "" {
return nil
}
portal.MXID = ""
- portal.RoomCreated.Clear()
err := portal.Save(ctx)
if err != nil {
return err
@@ -5403,10 +4278,8 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) {
}
func (portal *Portal) unlockedDelete(ctx context.Context) error {
- if portal.deleted.IsSet() {
- return nil
- }
- err := portal.safeDBDelete(ctx)
+ // TODO delete child portals?
+ err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
if err != nil {
return err
}
@@ -5415,14 +4288,10 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error {
}
func (portal *Portal) unlockedDeleteCache() {
- if portal.deleted.IsSet() {
- return
- }
delete(portal.Bridge.portalsByKey, portal.PortalKey)
if portal.MXID != "" {
delete(portal.Bridge.portalsByMXID, portal.MXID)
}
- portal.deleted.Set()
if portal.events != nil {
// TODO there's a small risk of this racing with a queueEvent call
close(portal.events)
@@ -5434,9 +4303,6 @@ func (portal *Portal) Save(ctx context.Context) error {
}
func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error {
- if portal.Receiver != "" && relay.ID != portal.Receiver {
- return fmt.Errorf("can't set non-receiver login as relay")
- }
portal.Relay = relay
if relay == nil {
portal.RelayLoginID = ""
diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go
index 879f07ae..9883fb12 100644
--- a/bridgev2/portalbackfill.go
+++ b/bridgev2/portalbackfill.go
@@ -194,9 +194,6 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t
if err != nil {
log.Err(err).Msg("Failed to get last thread message")
return
- } else if anchorMessage == nil {
- log.Warn().Msg("No messages found in thread?")
- return
}
resp := portal.fetchThreadBackfill(ctx, source, anchorMessage)
if resp != nil {
@@ -342,7 +339,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
for i, part := range msg.Parts {
partIDs = append(partIDs, part.ID)
portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent)
- part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent()
evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID)
dbMessage := &database.Message{
ID: msg.ID,
@@ -383,23 +379,19 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
prevThreadEvent.MXID = evtID
out.PrevThreadEvents[*msg.ThreadRoot] = evtID
}
- if msg.Disappear.Type != event.DisappearingTypeNone {
- if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
+ if msg.Disappear.Type != database.DisappearingTypeNone {
+ if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer)
}
out.Disappear = append(out.Disappear, &database.DisappearingMessage{
RoomID: portal.MXID,
EventID: evtID,
- Timestamp: msg.Timestamp,
DisappearingSetting: msg.Disappear,
})
}
}
slices.Sort(partIDs)
for _, reaction := range msg.Reactions {
- if reaction == nil {
- continue
- }
reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove)
if !ok {
continue
@@ -410,7 +402,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
if reaction.Timestamp.IsZero() {
reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond)
}
- //lint:ignore SA4006 it's a todo
targetPart, ok := partMap[*reaction.TargetPart]
if !ok {
// TODO warning log and/or skip reaction?
diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go
index 4c7e2447..e82c481a 100644
--- a/bridgev2/portalinternal.go
+++ b/bridgev2/portalinternal.go
@@ -37,8 +37,8 @@ func (portal *PortalInternals) EventLoop() {
(*Portal)(portal).eventLoop()
}
-func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) {
- return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt)
+func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) {
+ return (*Portal)(portal).handleSingleEventAsync(idx, rawEvt)
}
func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context {
@@ -49,10 +49,6 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any
(*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback)
}
-func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
- return (*Portal)(portal).unwrapBeeperSendState(ctx, evt)
-}
-
func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) {
(*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID)
}
@@ -65,8 +61,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i
return (*Portal)(portal).checkConfusableName(ctx, userID, name)
}
-func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
- return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest)
+func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt)
}
func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult {
@@ -77,10 +73,6 @@ func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user
(*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt)
}
-func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) {
- (*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal)
-}
-
func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixTyping(ctx, evt)
}
@@ -125,24 +117,12 @@ func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.User
return (*Portal)(portal).getTargetUser(ctx, userID)
}
-func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt)
+func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt)
}
-func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
- return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest)
-}
-
-func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
- return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest)
-}
-
-func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixTombstone(ctx, evt)
-}
-
-func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) {
- (*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser)
+func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt)
}
func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
@@ -153,10 +133,6 @@ func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *Us
return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt)
}
-func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) {
- (*Portal)(portal).ensureFunctionalMember(ctx, ghost)
-}
-
func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) {
return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType)
}
@@ -257,10 +233,6 @@ func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, sourc
return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt)
}
-func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) {
- return (*Portal)(portal).findOtherLogins(ctx, source)
-}
-
func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult {
return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt)
}
@@ -269,16 +241,16 @@ func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source
return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill)
}
-func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
- return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline)
+func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool {
+ return (*Portal)(portal).updateName(ctx, name, sender, ts)
}
-func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
- return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline)
+func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool {
+ return (*Portal)(portal).updateTopic(ctx, topic, sender, ts)
}
-func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
- return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline)
+func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool {
+ return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts)
}
func (portal *PortalInternals) GetBridgeInfoStateKey() string {
@@ -293,12 +265,8 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen
return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts)
}
-func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool {
- return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra)
-}
-
-func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) {
- (*Portal)(portal).revertRoomMeta(ctx, evt)
+func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool {
+ return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content)
}
func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) {
@@ -309,10 +277,6 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha
return (*Portal)(portal).updateOtherUser(ctx, members)
}
-func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool {
- return (*Portal)(portal).roomIsPublic(ctx)
-}
-
func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error {
return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts)
}
@@ -333,10 +297,6 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc
return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle)
}
-func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) {
- (*Portal)(portal).addToUserSpaces(ctx)
-}
-
func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) {
(*Portal)(portal).removeInPortalCache(ctx)
}
@@ -400,3 +360,7 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save
func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error {
return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove)
}
+
+func (portal *PortalInternals) SetMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
+ return (*Portal)(portal).setMXIDToExistingRoom(ctx, roomID)
+}
diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go
index c976d97c..a25fe820 100644
--- a/bridgev2/portalreid.go
+++ b/bridgev2/portalreid.go
@@ -32,40 +32,21 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
if source == target {
return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same")
}
- log := zerolog.Ctx(ctx).With().
- Str("action", "re-id portal").
- Stringer("source_portal_key", source).
- Stringer("target_portal_key", target).
- Logger()
- ctx = log.WithContext(ctx)
+ log := zerolog.Ctx(ctx)
+ log.Debug().Msg("Re-ID'ing portal")
defer func() {
log.Debug().Msg("Finished handling portal re-ID")
}()
- acquireCacheLock := func() {
- if !br.cacheLock.TryLock() {
- log.Debug().Msg("Waiting for global cache lock")
- br.cacheLock.Lock()
- log.Debug().Msg("Acquired global cache lock after waiting")
- } else {
- log.Trace().Msg("Acquired global cache lock without waiting")
- }
- }
- log.Debug().Msg("Re-ID'ing portal")
- sourcePortal, err := br.GetExistingPortalByKey(ctx, source)
+ br.cacheLock.Lock()
+ defer br.cacheLock.Unlock()
+ sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err)
} else if sourcePortal == nil {
log.Debug().Msg("Source portal not found, re-ID is no-op")
return ReIDResultNoOp, nil, nil
}
- if !sourcePortal.roomCreateLock.TryLock() {
- if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
- (*cancelCreate)()
- }
- log.Debug().Msg("Waiting for source portal room creation lock")
- sourcePortal.roomCreateLock.Lock()
- log.Debug().Msg("Acquired source portal room creation lock after waiting")
- }
+ sourcePortal.roomCreateLock.Lock()
defer sourcePortal.roomCreateLock.Unlock()
if sourcePortal.MXID == "" {
log.Info().Msg("Source portal doesn't have Matrix room, deleting row")
@@ -78,37 +59,22 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Stringer("source_portal_mxid", sourcePortal.MXID)
})
-
- acquireCacheLock()
targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true)
if err != nil {
- br.cacheLock.Unlock()
return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err)
}
if targetPortal == nil {
log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal")
err = sourcePortal.unlockedReID(ctx, target)
- br.cacheLock.Unlock()
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err)
}
return ReIDResultSourceReIDd, sourcePortal, nil
}
- br.cacheLock.Unlock()
-
- if !targetPortal.roomCreateLock.TryLock() {
- if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
- (*cancelCreate)()
- }
- log.Debug().Msg("Waiting for target portal room creation lock")
- targetPortal.roomCreateLock.Lock()
- log.Debug().Msg("Acquired target portal room creation lock after waiting")
- }
+ targetPortal.roomCreateLock.Lock()
defer targetPortal.roomCreateLock.Unlock()
if targetPortal.MXID == "" {
log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal")
- acquireCacheLock()
- defer br.cacheLock.Unlock()
err = targetPortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err)
@@ -123,9 +89,6 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
return c.Stringer("target_portal_mxid", targetPortal.MXID)
})
log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal")
- sourcePortal.removeInPortalCache(ctx)
- acquireCacheLock()
- defer br.cacheLock.Unlock()
err = sourcePortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err)
@@ -133,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
go func() {
_, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{
Parsed: &event.TombstoneEventContent{
- Body: "This room has been merged",
+ Body: fmt.Sprintf("This room has been merged"),
ReplacementRoom: targetPortal.MXID,
},
}, time.Now())
diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go
deleted file mode 100644
index 72bacaff..00000000
--- a/bridgev2/provisionutil/creategroup.go
+++ /dev/null
@@ -1,149 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package provisionutil
-
-import (
- "context"
-
- "github.com/rs/zerolog"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/bridgev2"
- "maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-type RespCreateGroup struct {
- ID networkid.PortalID `json:"id"`
- MXID id.RoomID `json:"mxid"`
- Portal *bridgev2.Portal `json:"-"`
-
- FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"`
-}
-
-func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) {
- api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI)
- if !ok {
- return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups"))
- }
- zerolog.Ctx(ctx).Debug().
- Any("create_params", params).
- Msg("Creating group chat on remote network")
- caps := login.Bridge.Network.GetCapabilities()
- typeSpec, validType := caps.Provisioning.GroupCreation[params.Type]
- if !validType {
- return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type))
- }
- if len(params.Participants) < typeSpec.Participants.MinLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength))
- } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength))
- }
- userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
- for i, participant := range params.Participants {
- parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant))
- if ok {
- participant = parsedParticipant
- params.Participants[i] = participant
- }
- if !typeSpec.Participants.SkipIdentifierValidation {
- if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant))
- }
- }
- if api.IsThisUser(ctx, participant) {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant))
- }
- }
- if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required"))
- } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength))
- } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength))
- }
- if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required"))
- }
- if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required"))
- } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength))
- } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength))
- }
- if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required"))
- } else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer"))
- }
- if params.Username == "" && typeSpec.Username.Required {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required"))
- } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength))
- } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength))
- }
- if params.Parent == nil && typeSpec.Parent.Required {
- return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required"))
- }
- resp, err := api.CreateGroup(ctx, params)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to create group")
- return nil, err
- }
- if resp.PortalKey.IsEmpty() {
- return nil, ErrNoPortalKey
- }
- zerolog.Ctx(ctx).Debug().
- Object("portal_key", resp.PortalKey).
- Msg("Successfully created group on remote network")
- if resp.Portal == nil {
- resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
- return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
- }
- }
- if resp.Portal.MXID == "" {
- err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
- return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
- }
- }
- for key, fp := range resp.FailedParticipants {
- if fp.InviteEventType == "" {
- fp.InviteEventType = event.EventMessage.Type
- }
- if fp.UserMXID == "" {
- ghost, err := login.Bridge.GetGhostByID(ctx, key)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant")
- } else if ghost != nil {
- fp.UserMXID = ghost.Intent.GetMXID()
- }
- }
- if fp.DMRoomMXID == "" {
- portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant")
- } else if portal != nil {
- fp.DMRoomMXID = portal.MXID
- }
- }
- }
- return &RespCreateGroup{
- ID: resp.Portal.ID,
- MXID: resp.Portal.MXID,
- Portal: resp.Portal,
-
- FailedParticipants: resp.FailedParticipants,
- }, nil
-}
diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go
deleted file mode 100644
index ce163e67..00000000
--- a/bridgev2/provisionutil/listcontacts.go
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package provisionutil
-
-import (
- "context"
-
- "github.com/rs/zerolog"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/bridgev2"
-)
-
-type RespGetContactList struct {
- Contacts []*RespResolveIdentifier `json:"contacts"`
-}
-
-type RespSearchUsers struct {
- Results []*RespResolveIdentifier `json:"results"`
-}
-
-func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) {
- api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
- if !ok {
- return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts"))
- }
- resp, err := api.GetContactList(ctx)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
- return nil, err
- }
- return &RespGetContactList{
- Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false),
- }, nil
-}
-
-func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) {
- api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
- if !ok {
- return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users"))
- }
- resp, err := api.SearchUsers(ctx, query)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
- return nil, err
- }
- return &RespSearchUsers{
- Results: processResolveIdentifiers(ctx, login.Bridge, resp, true),
- }, nil
-}
-
-func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) {
- apiResp = make([]*RespResolveIdentifier, len(resp))
- for i, contact := range resp {
- apiContact := &RespResolveIdentifier{
- ID: contact.UserID,
- }
- apiResp[i] = apiContact
- if contact.UserInfo != nil {
- if contact.UserInfo.Name != nil {
- apiContact.Name = *contact.UserInfo.Name
- }
- if contact.UserInfo.Identifiers != nil {
- apiContact.Identifiers = contact.UserInfo.Identifiers
- }
- }
- if contact.Ghost != nil {
- if syncInfo && contact.UserInfo != nil {
- contact.Ghost.UpdateInfo(ctx, contact.UserInfo)
- }
- if contact.Ghost.Name != "" {
- apiContact.Name = contact.Ghost.Name
- }
- if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
- apiContact.Identifiers = contact.Ghost.Identifiers
- }
- apiContact.AvatarURL = contact.Ghost.AvatarMXC
- apiContact.MXID = contact.Ghost.Intent.GetMXID()
- }
- if contact.Chat != nil {
- if contact.Chat.Portal == nil {
- var err error
- contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
- }
- }
- if contact.Chat.Portal != nil {
- apiContact.DMRoomID = contact.Chat.Portal.MXID
- }
- }
- }
- return
-}
diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go
deleted file mode 100644
index cfc388d0..00000000
--- a/bridgev2/provisionutil/resolveidentifier.go
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package provisionutil
-
-import (
- "context"
- "errors"
-
- "github.com/rs/zerolog"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/bridgev2"
- "maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/id"
-)
-
-type RespResolveIdentifier struct {
- ID networkid.UserID `json:"id"`
- Name string `json:"name,omitempty"`
- AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
- Identifiers []string `json:"identifiers,omitempty"`
- MXID id.UserID `json:"mxid,omitempty"`
- DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
-
- Portal *bridgev2.Portal `json:"-"`
- Ghost *bridgev2.Ghost `json:"-"`
- JustCreated bool `json:"-"`
-}
-
-var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request")
-
-func ResolveIdentifier(
- ctx context.Context,
- login *bridgev2.UserLogin,
- identifier string,
- createChat bool,
-) (*RespResolveIdentifier, error) {
- api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
- if !ok {
- return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers"))
- }
- var resp *bridgev2.ResolveIdentifierResponse
- parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier))
- validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
- if ok && (!vOK || validator.ValidateUserID(parsedUserID)) {
- ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID")
- return nil, err
- }
- resp = &bridgev2.ResolveIdentifierResponse{
- Ghost: ghost,
- UserID: parsedUserID,
- }
- gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI)
- if ok && createChat {
- resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat")
- return nil, err
- }
- } else if createChat || ghost.Name == "" {
- zerolog.Ctx(ctx).Debug().
- Bool("create_chat", createChat).
- Bool("has_name", ghost.Name != "").
- Msg("Falling back to resolving identifier")
- resp = nil
- identifier = string(parsedUserID)
- }
- }
- if resp == nil {
- var err error
- resp, err = api.ResolveIdentifier(ctx, identifier, createChat)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier")
- return nil, err
- } else if resp == nil {
- return nil, nil
- }
- }
- apiResp := &RespResolveIdentifier{
- ID: resp.UserID,
- Ghost: resp.Ghost,
- }
- if resp.Ghost != nil {
- if resp.UserInfo != nil {
- resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
- }
- apiResp.Name = resp.Ghost.Name
- apiResp.AvatarURL = resp.Ghost.AvatarMXC
- apiResp.Identifiers = resp.Ghost.Identifiers
- apiResp.MXID = resp.Ghost.Intent.GetMXID()
- } else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
- apiResp.Name = *resp.UserInfo.Name
- }
- if resp.Chat != nil {
- if resp.Chat.PortalKey.IsEmpty() {
- return nil, ErrNoPortalKey
- }
- if resp.Chat.Portal == nil {
- var err error
- resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
- return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
- }
- }
- resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID)
- if createChat && resp.Chat.Portal.MXID == "" {
- apiResp.JustCreated = true
- err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo)
- if err != nil {
- zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
- return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
- }
- }
- apiResp.Portal = resp.Chat.Portal
- apiResp.DMRoomID = resp.Chat.Portal.MXID
- }
- return apiResp, nil
-}
diff --git a/bridgev2/queue.go b/bridgev2/queue.go
index 3775c825..04d982b5 100644
--- a/bridgev2/queue.go
+++ b/bridgev2/queue.go
@@ -63,13 +63,6 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve
return true
}
-var (
- ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
- ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
- ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage())
- ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage()
-)
-
func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult {
// TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands
@@ -85,11 +78,13 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
return EventHandlingResultFailed
} else if sender == nil {
log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event")
- br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
+ status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
+ br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
return EventHandlingResultFailed
} else if !sender.Permissions.SendEvents {
if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") {
- br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt))
+ status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
+ br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
}
return EventHandlingResultIgnored
} else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") {
@@ -97,7 +92,8 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
}
} else if evt.Type.Class != event.EphemeralEventType {
log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event")
- br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
+ status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
+ br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
return EventHandlingResultIgnored
}
if evt.Type == event.EventMessage && sender != nil {
@@ -106,10 +102,11 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
msg.RemovePerMessageProfileFallback()
if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom {
if !sender.Permissions.Commands {
- br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt))
+ status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
+ br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
return EventHandlingResultIgnored
}
- go br.Commands.Handle(
+ br.Commands.Handle(
ctx,
evt.RoomID,
evt.ID,
@@ -117,7 +114,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "),
msg.RelatesTo.GetReplyTo(),
)
- return EventHandlingResultQueued
+ return EventHandlingResultSuccess
}
}
if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil {
@@ -160,27 +157,10 @@ type EventHandlingResult struct {
Ignored bool
Queued bool
- SkipStateEcho bool
-
// Error is an optional reason for failure. It is not required, Success may be false even without a specific error.
Error error
// Whether the Error should be sent as a MSS event.
SendMSS bool
-
- // EventID from the network
- EventID id.EventID
- // Stream order from the network
- StreamOrder int64
-}
-
-func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult {
- ehr.EventID = id
- return ehr
-}
-
-func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult {
- ehr.StreamOrder = order
- return ehr
}
func (ehr EventHandlingResult) WithError(err error) EventHandlingResult {
@@ -197,11 +177,6 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult {
return ehr
}
-func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult {
- ehr.SkipStateEcho = skip
- return ehr
-}
-
func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult {
if err == nil {
return ehr
@@ -220,7 +195,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult {
return ul.Bridge.QueueRemoteEvent(ul, evt)
}
-func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult {
+func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) {
log := login.Log
ctx := log.WithContext(br.BackgroundCtx)
maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver)
@@ -236,14 +211,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandl
if err != nil {
log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain).
Msg("Failed to get portal to handle remote event")
- return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err))
+ return
} else if portal == nil {
log.Warn().
Stringer("event_type", evt.GetType()).
Object("portal_key", key).
Bool("uncertain_receiver", isUncertain).
Msg("Portal not found to handle remote event")
- return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler)
+ return
}
// TODO put this in a better place, and maybe cache to avoid constant db queries
login.MarkInPortal(ctx, portal)
diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go
index 56e3a6b1..c725141b 100644
--- a/bridgev2/simplevent/chat.go
+++ b/bridgev2/simplevent/chat.go
@@ -65,19 +65,14 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal)
type ChatDelete struct {
EventMeta
OnlyForMe bool
- Children bool
}
-var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil)
+var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil)
func (evt *ChatDelete) DeleteOnlyForMe() bool {
return evt.OnlyForMe
}
-func (evt *ChatDelete) DeleteChildren() bool {
- return evt.Children
-}
-
// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange].
type ChatInfoChange struct {
EventMeta
diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go
index f8f8d7e1..f648ab12 100644
--- a/bridgev2/simplevent/message.go
+++ b/bridgev2/simplevent/message.go
@@ -59,41 +59,6 @@ func (evt *Message[T]) GetTransactionID() networkid.TransactionID {
return evt.TransactionID
}
-// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data.
-type PreConvertedMessage struct {
- EventMeta
- Data *bridgev2.ConvertedMessage
- ID networkid.MessageID
- TransactionID networkid.TransactionID
-
- HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error)
-}
-
-var (
- _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil)
- _ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil)
- _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil)
-)
-
-func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) {
- return evt.Data, nil
-}
-
-func (evt *PreConvertedMessage) GetID() networkid.MessageID {
- return evt.ID
-}
-
-func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID {
- return evt.TransactionID
-}
-
-func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) {
- if evt.HandleExistingFunc == nil {
- return bridgev2.UpsertResult{}, nil
- }
- return evt.HandleExistingFunc(ctx, portal, intent, existing)
-}
-
type MessageRemove struct {
EventMeta
diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go
index 96c8a9c5..8aa91866 100644
--- a/bridgev2/simplevent/meta.go
+++ b/bridgev2/simplevent/meta.go
@@ -27,9 +27,8 @@ type EventMeta struct {
Timestamp time.Time
StreamOrder int64
- PreHandleFunc func(context.Context, *bridgev2.Portal)
- PostHandleFunc func(context.Context, *bridgev2.Portal)
- MutateContextFunc func(context.Context) context.Context
+ PreHandleFunc func(context.Context, *bridgev2.Portal)
+ PostHandleFunc func(context.Context, *bridgev2.Portal)
}
var (
@@ -40,7 +39,6 @@ var (
_ bridgev2.RemoteEventWithStreamOrder = (*EventMeta)(nil)
_ bridgev2.RemotePreHandler = (*EventMeta)(nil)
_ bridgev2.RemotePostHandler = (*EventMeta)(nil)
- _ bridgev2.RemoteEventWithContextMutation = (*EventMeta)(nil)
)
func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context {
@@ -93,13 +91,6 @@ func (evt *EventMeta) PostHandle(ctx context.Context, portal *bridgev2.Portal) {
}
}
-func (evt *EventMeta) MutateContext(ctx context.Context) context.Context {
- if evt.MutateContextFunc == nil {
- return ctx
- }
- return evt.MutateContextFunc(ctx)
-}
-
func (evt EventMeta) WithType(t bridgev2.RemoteEventType) EventMeta {
evt.Type = t
return evt
@@ -110,18 +101,6 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E
return evt
}
-func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta {
- origFunc := evt.LogContext
- if origFunc == nil {
- evt.LogContext = f
- return evt
- }
- evt.LogContext = func(c zerolog.Context) zerolog.Context {
- return f(origFunc(c))
- }
- return evt
-}
-
func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta {
evt.PortalKey = p
return evt
diff --git a/bridgev2/space.go b/bridgev2/space.go
index 2ca2bce3..ccb74b26 100644
--- a/bridgev2/space.go
+++ b/bridgev2/space.go
@@ -164,17 +164,14 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) {
ul.UserMXID: 50,
},
},
- Invite: []id.UserID{ul.UserMXID},
+ RoomVersion: event.RoomV11,
+ Invite: []id.UserID{ul.UserMXID},
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{ul.UserMXID}
// TODO remove this after initial_members is supported in hungryserv
req.BeeperAutoJoinInvites = true
}
- pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI)
- if ok {
- pfc.CustomizePersonalFilteringSpace(req)
- }
ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req)
if err != nil {
return "", fmt.Errorf("failed to create space room: %w", err)
diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go
index 5925dd4f..01a235a0 100644
--- a/bridgev2/status/bridgestate.go
+++ b/bridgev2/status/bridgestate.go
@@ -19,10 +19,9 @@ import (
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
+ "go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
- "maunium.net/go/mautrix/bridgev2/networkid"
- "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -88,8 +87,6 @@ type RemoteProfile struct {
Username string `json:"username,omitempty"`
Name string `json:"name,omitempty"`
Avatar id.ContentURIString `json:"avatar,omitempty"`
-
- AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"`
}
func coalesce[T ~string](a, b T) T {
@@ -105,14 +102,11 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile {
other.Username = coalesce(rp.Username, other.Username)
other.Name = coalesce(rp.Name, other.Name)
other.Avatar = coalesce(rp.Avatar, other.Avatar)
- if rp.AvatarFile != nil {
- other.AvatarFile = rp.AvatarFile
- }
return other
}
-func (rp *RemoteProfile) IsZero() bool {
- return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil)
+func (rp *RemoteProfile) IsEmpty() bool {
+ return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "")
}
type BridgeState struct {
@@ -126,10 +120,10 @@ type BridgeState struct {
UserAction BridgeStateUserAction `json:"user_action,omitempty"`
- UserID id.UserID `json:"user_id,omitempty"`
- RemoteID networkid.UserLoginID `json:"remote_id,omitempty"`
- RemoteName string `json:"remote_name,omitempty"`
- RemoteProfile RemoteProfile `json:"remote_profile,omitzero"`
+ UserID id.UserID `json:"user_id,omitempty"`
+ RemoteID string `json:"remote_id,omitempty"`
+ RemoteName string `json:"remote_name,omitempty"`
+ RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"`
Reason string `json:"reason,omitempty"`
Info map[string]interface{} `json:"info,omitempty"`
@@ -209,7 +203,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool {
pong.StateEvent == newPong.StateEvent &&
pong.RemoteName == newPong.RemoteName &&
pong.UserAction == newPong.UserAction &&
- pong.RemoteProfile == newPong.RemoteProfile &&
+ ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) &&
pong.Error == newPong.Error &&
maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) &&
pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now())
diff --git a/bridgev2/status/messagecheckpoint.go b/bridgev2/status/messagecheckpoint.go
index b3c05f4f..ea859b84 100644
--- a/bridgev2/status/messagecheckpoint.go
+++ b/bridgev2/status/messagecheckpoint.go
@@ -169,13 +169,13 @@ type CheckpointsJSON struct {
Checkpoints []*MessageCheckpoint `json:"checkpoints"`
}
-func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpoint string, token string) error {
+func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error {
var body bytes.Buffer
if err := json.NewEncoder(&body).Encode(cj); err != nil {
return fmt.Errorf("failed to encode message checkpoint JSON: %w", err)
}
- ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &body)
if err != nil {
@@ -186,10 +186,7 @@ func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpo
req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (checkpoint sender)")
req.Header.Set("Content-Type", "application/json")
- if cli == nil {
- cli = http.DefaultClient
- }
- resp, err := cli.Do(req)
+ resp, err := http.DefaultClient.Do(req)
if err != nil {
return mautrix.HTTPError{
Request: req,
diff --git a/bridgev2/user.go b/bridgev2/user.go
index 9a7896d6..350cecd1 100644
--- a/bridgev2/user.go
+++ b/bridgev2/user.go
@@ -176,10 +176,6 @@ func (user *User) GetUserLogins() []*UserLogin {
return maps.Values(user.logins)
}
-func (user *User) HasTooManyLogins() bool {
- return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins
-}
-
func (user *User) GetFormattedUserLogins() string {
user.Bridge.cacheLock.Lock()
logins := make([]string, len(user.logins))
@@ -229,8 +225,9 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) {
user.MXID: 50,
},
},
- Invite: []id.UserID{user.MXID},
- IsDirect: true,
+ RoomVersion: event.RoomV11,
+ Invite: []id.UserID{user.MXID},
+ IsDirect: true,
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{user.MXID}
diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go
index d56dc4cc..203dc122 100644
--- a/bridgev2/userlogin.go
+++ b/bridgev2/userlogin.go
@@ -10,7 +10,6 @@ import (
"cmp"
"context"
"fmt"
- "maps"
"slices"
"sync"
"time"
@@ -51,8 +50,6 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
- // TODO if loading the user caused the provided userlogin to be loaded, cancel here?
- // Currently this will double-load it
}
userLogin := &UserLogin{
UserLogin: dbUserLogin,
@@ -143,12 +140,6 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin {
return br.userLoginsByID[id]
}
-func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) {
- br.cacheLock.Lock()
- defer br.cacheLock.Unlock()
- return slices.Collect(maps.Values(br.userLoginsByID))
-}
-
func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
@@ -510,9 +501,9 @@ var _ status.BridgeStateFiller = (*UserLogin)(nil)
func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState {
state.UserID = ul.UserMXID
- state.RemoteID = ul.ID
+ state.RemoteID = string(ul.ID)
state.RemoteName = ul.RemoteName
- state.RemoteProfile = ul.RemoteProfile
+ state.RemoteProfile = &ul.RemoteProfile
filler, ok := ul.Client.(status.BridgeStateFiller)
if ok {
return filler.FillBridgeState(state)
diff --git a/client.go b/client.go
index 045d7b8e..6f746015 100644
--- a/client.go
+++ b/client.go
@@ -13,7 +13,6 @@ import (
"net/http"
"net/url"
"os"
- "runtime"
"slices"
"strconv"
"strings"
@@ -111,8 +110,6 @@ type Client struct {
// Set to true to disable automatically sleeping on 429 errors.
IgnoreRateLimit bool
- ResponseSizeLimit int64
-
txnID int32
// Should the ?user_id= query parameter be set in requests?
@@ -142,12 +139,6 @@ type IdentityServerInfo struct {
// Use ParseUserID to extract the server name from a user ID.
// https://spec.matrix.org/v1.2/client-server-api/#server-discovery
func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) {
- return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName)
-}
-
-const WellKnownMaxSize = 64 * 1024
-
-func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) {
wellKnownURL := url.URL{
Scheme: "https",
Host: serverName,
@@ -159,11 +150,10 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve
return nil, err
}
- if runtime.GOOS != "js" {
- req.Header.Set("Accept", "application/json")
- req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)")
- }
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)")
+ client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
@@ -172,15 +162,11 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve
if resp.StatusCode == http.StatusNotFound {
return nil, nil
- } else if resp.ContentLength > WellKnownMaxSize {
- return nil, errors.New(".well-known response too large")
}
- data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize))
+ data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
- } else if len(data) >= WellKnownMaxSize {
- return nil, errors.New(".well-known response too large")
}
var wellKnown ClientWellKnown
@@ -331,7 +317,6 @@ const (
LogBodyContextKey contextKey = iota
LogRequestIDContextKey
MaxAttemptsContextKey
- SyncTokenContextKey
)
func (cli *Client) RequestStart(req *http.Request) {
@@ -386,14 +371,7 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er
}
}
if body := req.Context().Value(LogBodyContextKey); body != nil {
- switch typedLogBody := body.(type) {
- case json.RawMessage:
- evt.RawJSON("req_body", typedLogBody)
- case string:
- evt.Str("req_body", typedLogBody)
- default:
- panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body))
- }
+ evt.Interface("req_body", body)
}
if errors.Is(err, context.Canceled) {
evt.Msg("Request canceled")
@@ -410,43 +388,32 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin
return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody})
}
-type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error)
+type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error)
type FullRequest struct {
- Method string
- URL string
- Headers http.Header
- RequestJSON interface{}
- RequestBytes []byte
- RequestBody io.Reader
- RequestLength int64
- ResponseJSON interface{}
- MaxAttempts int
- BackoffDuration time.Duration
- SensitiveContent bool
- Handler ClientResponseHandler
- DontReadResponse bool
- ResponseSizeLimit int64
- Logger *zerolog.Logger
- Client *http.Client
+ Method string
+ URL string
+ Headers http.Header
+ RequestJSON interface{}
+ RequestBytes []byte
+ RequestBody io.Reader
+ RequestLength int64
+ ResponseJSON interface{}
+ MaxAttempts int
+ BackoffDuration time.Duration
+ SensitiveContent bool
+ Handler ClientResponseHandler
+ DontReadResponse bool
+ Logger *zerolog.Logger
+ Client *http.Client
}
var requestID int32
var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes"
func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) {
- reqID := atomic.AddInt32(&requestID, 1)
- logger := zerolog.Ctx(ctx)
- if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
- logger = params.Logger
- }
- ctx = logger.With().
- Int32("req_id", reqID).
- Logger().WithContext(ctx)
-
var logBody any
- var reqBody io.Reader
- var reqLen int64
+ reqBody := params.RequestBody
if params.RequestJSON != nil {
jsonStr, err := json.Marshal(params.RequestJSON)
if err != nil {
@@ -457,38 +424,33 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
}
if params.SensitiveContent && !logSensitiveContent {
logBody = ""
- } else if len(jsonStr) > 32768 {
- logBody = fmt.Sprintf("", len(jsonStr))
} else {
- logBody = json.RawMessage(jsonStr)
+ logBody = params.RequestJSON
}
reqBody = bytes.NewReader(jsonStr)
- reqLen = int64(len(jsonStr))
} else if params.RequestBytes != nil {
logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes))
reqBody = bytes.NewReader(params.RequestBytes)
- reqLen = int64(len(params.RequestBytes))
- } else if params.RequestBody != nil {
- logBody = ""
- reqLen = -1
- if params.RequestLength > 0 {
- logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
- reqLen = params.RequestLength
- } else if params.RequestLength == 0 {
- zerolog.Ctx(ctx).Warn().
- Msg("RequestBody passed without specifying request length")
- }
- reqBody = params.RequestBody
+ params.RequestLength = int64(len(params.RequestBytes))
+ } else if params.RequestLength > 0 && params.RequestBody != nil {
+ logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok {
// Prevent HTTP from closing the request body, it might be needed for retries
reqBody = nopCloseSeeker{rsc}
}
} else if params.Method != http.MethodGet && params.Method != http.MethodHead {
params.RequestJSON = struct{}{}
- logBody = json.RawMessage("{}")
+ logBody = params.RequestJSON
reqBody = bytes.NewReader([]byte("{}"))
- reqLen = 2
}
+ reqID := atomic.AddInt32(&requestID, 1)
+ logger := zerolog.Ctx(ctx)
+ if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
+ logger = params.Logger
+ }
+ ctx = logger.With().
+ Int32("req_id", reqID).
+ Logger().WithContext(ctx)
ctx = context.WithValue(ctx, LogBodyContextKey, logBody)
ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID))
req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody)
@@ -504,7 +466,9 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
if params.RequestJSON != nil {
req.Header.Set("Content-Type", "application/json")
}
- req.ContentLength = reqLen
+ if params.RequestLength > 0 && params.RequestBody != nil {
+ req.ContentLength = params.RequestLength
+ }
return req, nil
}
@@ -549,31 +513,14 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque
params.Handler = handleNormalResponse
}
}
- if cli.UserAgent != "" {
- req.Header.Set("User-Agent", cli.UserAgent)
- }
+ req.Header.Set("User-Agent", cli.UserAgent)
if len(cli.AccessToken) > 0 {
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
}
- if params.ResponseSizeLimit == 0 {
- params.ResponseSizeLimit = cli.ResponseSizeLimit
- }
- if params.ResponseSizeLimit == 0 {
- params.ResponseSizeLimit = DefaultResponseSizeLimit
- }
if params.Client == nil {
params.Client = cli.Client
}
- return cli.executeCompiledRequest(
- req,
- params.MaxAttempts-1,
- params.BackoffDuration,
- params.ResponseJSON,
- params.Handler,
- params.DontReadResponse,
- params.ResponseSizeLimit,
- params.Client,
- )
+ return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client)
}
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
@@ -584,17 +531,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
return log
}
-func (cli *Client) doRetry(
- req *http.Request,
- cause error,
- retries int,
- backoff time.Duration,
- responseJSON any,
- handler ClientResponseHandler,
- dontReadResponse bool,
- sizeLimit int64,
- client *http.Client,
-) ([]byte, *http.Response, error) {
+func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
log := zerolog.Ctx(req.Context())
if req.Body != nil {
var err error
@@ -616,37 +553,21 @@ func (cli *Client) doRetry(
}
}
log.Warn().Err(cause).
- Str("method", req.Method).
- Str("url", req.URL.String()).
Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Request failed, retrying")
select {
case <-time.After(backoff):
case <-req.Context().Done():
- if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) {
- return nil, nil, req.Context().Err()
- }
+ return nil, nil, req.Context().Err()
}
if cli.UpdateRequestOnRetry != nil {
req = cli.UpdateRequestOnRetry(req, cause)
}
- return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client)
+ return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client)
}
-func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) {
- if res.ContentLength > limit {
- return nil, HTTPError{
- Request: req,
- Response: res,
-
- Message: "not reading response",
- WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
- }
- }
- contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1))
- if err == nil && len(contents) > int(limit) {
- err = ErrBodyReadReachedLimit
- }
+func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) {
+ contents, err := io.ReadAll(res.Body)
if err != nil {
return nil, HTTPError{
Request: req,
@@ -667,20 +588,17 @@ func closeTemp(log *zerolog.Logger, file *os.File) {
}
}
-func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
+func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
log := zerolog.Ctx(req.Context())
file, err := os.CreateTemp("", "mautrix-response-")
if err != nil {
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
- _, err = handleNormalResponse(req, res, responseJSON, limit)
+ _, err = handleNormalResponse(req, res, responseJSON)
return nil, err
}
defer closeTemp(log, file)
- var n int64
- if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil {
+ if _, err = io.Copy(file, res.Body); err != nil {
return nil, fmt.Errorf("failed to copy response to file: %w", err)
- } else if n > limit {
- return nil, ErrBodyReadReachedLimit
} else if _, err = file.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err)
} else if err = json.NewDecoder(file).Decode(responseJSON); err != nil {
@@ -690,12 +608,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON any, lim
}
}
-func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
+func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
return nil, nil
}
-func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
- if contents, err := readResponseBody(req, res, limit); err != nil {
+func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
+ if contents, err := readResponseBody(req, res); err != nil {
return nil, err
} else if responseJSON == nil {
return contents, nil
@@ -713,13 +631,8 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON an
}
}
-const ErrorResponseSizeLimit = 512 * 1024
-
-var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024
-
func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
- defer res.Body.Close()
- contents, err := readResponseBody(req, res, ErrorResponseSizeLimit)
+ contents, err := readResponseBody(req, res)
if err != nil {
return contents, err
}
@@ -738,31 +651,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
}
}
-func (cli *Client) executeCompiledRequest(
- req *http.Request,
- retries int,
- backoff time.Duration,
- responseJSON any,
- handler ClientResponseHandler,
- dontReadResponse bool,
- sizeLimit int64,
- client *http.Client,
-) ([]byte, *http.Response, error) {
+func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
cli.RequestStart(req)
startTime := time.Now()
res, err := client.Do(req)
- duration := time.Since(startTime)
+ duration := time.Now().Sub(startTime)
if res != nil && !dontReadResponse {
defer res.Body.Close()
}
if err != nil {
- // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry
- canRetry := !errors.Is(err, context.Canceled) ||
- errors.Is(context.Cause(req.Context()), ErrContextCancelRetry)
- if retries > 0 && canRetry {
- return cli.doRetry(
- req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
- )
+ if retries > 0 && !errors.Is(err, context.Canceled) {
+ return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client)
}
err = HTTPError{
Request: req,
@@ -777,9 +676,7 @@ func (cli *Client) executeCompiledRequest(
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
- return cli.doRetry(
- req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
- )
+ return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client)
}
var body []byte
@@ -787,7 +684,7 @@ func (cli *Client) executeCompiledRequest(
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
} else {
- body, err = handler(req, res, responseJSON, sizeLimit)
+ body, err = handler(req, res, responseJSON)
cli.LogRequestDone(req, res, nil, err, len(body), duration)
}
return body, res, err
@@ -847,7 +744,7 @@ func (req *ReqSync) BuildQuery() map[string]string {
query["full_state"] = "true"
}
if req.UseStateAfter {
- query["use_state_after"] = "true"
+ query["org.matrix.msc4222.use_state_after"] = "true"
}
if req.BeeperStreaming {
query["com.beeper.streaming"] = "true"
@@ -871,7 +768,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp
}
start := time.Now()
_, err = cli.MakeFullRequest(ctx, fullReq)
- duration := time.Since(start)
+ duration := time.Now().Sub(start)
timeout := time.Duration(req.Timeout) * time.Millisecond
buffer := 10 * time.Second
if req.Since == "" {
@@ -918,7 +815,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp
return
}
-func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
+func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
var bodyBytes []byte
bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
@@ -942,7 +839,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[an
// Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
//
// Registers with kind=user. For kind=guest, see RegisterGuest.
-func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
+func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
u := cli.BuildClientURL("v3", "register")
return cli.register(ctx, u, req)
}
@@ -951,7 +848,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRe
// with kind=guest.
//
// For kind=user, see Register.
-func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
+func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
query := map[string]string{
"kind": "guest",
}
@@ -974,8 +871,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*R
// panic(err)
// }
// token := res.AccessToken
-func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) {
- _, uia, err := cli.Register(ctx, req)
+func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) {
+ res, uia, err := cli.Register(ctx, req)
if err != nil && uia == nil {
return nil, err
} else if uia == nil {
@@ -984,7 +881,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*R
return nil, errors.New("server does not support m.login.dummy")
}
req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session}
- res, _, err := cli.Register(ctx, req)
+ res, _, err = cli.Register(ctx, req)
if err != nil {
return nil, err
}
@@ -1148,19 +1045,8 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs
return
}
-func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) {
- urlPath := cli.BuildClientURL("v3", "user_directory", "search")
- _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{
- SearchTerm: query,
- Limit: limit,
- }, &resp)
- return
-}
-
func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) {
- supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms)
- supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms)
- if cli.SpecVersions != nil && !supportsUnstable && !supportsStable {
+ if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) {
err = fmt.Errorf("server does not support fetching mutual rooms")
return
}
@@ -1170,10 +1056,7 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex
if len(extras) > 0 {
query["from"] = extras[0].From
}
- urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "mutual_rooms"}, query)
- if !supportsStable && supportsUnstable {
- urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
- }
+ urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
@@ -1195,7 +1078,8 @@ func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via
// GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname
func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) {
- err = cli.GetProfileField(ctx, mxid, "displayname", &resp)
+ urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname")
+ _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
@@ -1206,47 +1090,41 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay
// SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname
func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) {
- return cli.SetProfileField(ctx, "displayname", displayName)
+ urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname")
+ s := struct {
+ DisplayName string `json:"displayname"`
+ }{displayName}
+ _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil)
+ return
}
-// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
-func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) {
- urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
- if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
- urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
- }
+// UnstableSetProfileField sets an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
+func (cli *Client) UnstableSetProfileField(ctx context.Context, key string, value any) (err error) {
+ urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{
key: value,
}, nil)
return
}
-// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
-func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) {
- urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
- if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
- urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
- }
+// UnstableDeleteProfileField deletes an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
+func (cli *Client) UnstableDeleteProfileField(ctx context.Context, key string) (err error) {
+ urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
return
}
-// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname
-func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) {
- urlPath := cli.BuildClientURL("v3", "profile", userID, key)
- if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
- urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
- }
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into)
- return
-}
-
// GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url
func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) {
+ urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url")
s := struct {
AvatarURL id.ContentURI `json:"avatar_url"`
}{}
- err = cli.GetProfileField(ctx, mxid, "avatar_url", &s)
+
+ _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s)
+ if err != nil {
+ return
+ }
url = s.AvatarURL
return
}
@@ -1338,9 +1216,6 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
- if req.UnstableStickyDuration > 0 {
- queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
- }
if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted {
var isEncrypted bool
@@ -1364,51 +1239,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
return
}
-// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint.
-// contentJSON should be a value that can be encoded as JSON using json.Marshal.
-func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
- var req ReqSendEvent
- if len(extra) > 0 {
- req = extra[0]
- }
-
- var txnID string
- if len(req.TransactionID) > 0 {
- txnID = req.TransactionID
- } else {
- txnID = cli.TxnID()
- }
-
- queryParams := map[string]string{}
- if req.Timestamp > 0 {
- queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
- }
-
- if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted {
- var isEncrypted bool
- isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID)
- if err != nil {
- err = fmt.Errorf("failed to check if room is encrypted: %w", err)
- return
- }
- if isEncrypted {
- if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil {
- err = fmt.Errorf("failed to encrypt event: %w", err)
- return
- }
- eventType = event.EventEncrypted
- }
- }
-
- urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID}
- urlPath := cli.BuildURLWithQuery(urlData, queryParams)
- _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
- return
-}
-
-// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
+// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
-func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
+func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
var req ReqSendEvent
if len(extra) > 0 {
req = extra[0]
@@ -1418,18 +1251,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
if req.MeowEventID != "" {
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
}
- if req.TransactionID != "" {
- queryParams["fi.mau.transaction_id"] = req.TransactionID
- }
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
- if req.UnstableStickyDuration > 0 {
- queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
- }
- if req.Timestamp > 0 {
- queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
- }
urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
@@ -1442,38 +1266,14 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
// SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
-//
-// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead.
func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
- resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{
- Timestamp: ts,
+ urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
+ "ts": strconv.FormatInt(ts, 10),
})
- return
-}
-
-func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) {
- query := map[string]string{}
- if req.DelayID != "" {
- query["delay_id"] = string(req.DelayID)
+ _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
+ if err == nil && cli.StateStore != nil {
+ cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
}
- if req.Status != "" {
- query["status"] = string(req.Status)
- }
- if req.NextBatch != "" {
- query["next_batch"] = req.NextBatch
- }
-
- urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query)
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp)
-
- // Migration: merge old keys with new ones
- if resp != nil {
- resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...)
- resp.DelayedEvents = nil
- resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...)
- resp.FinalisedEvents = nil
- }
-
return
}
@@ -1564,10 +1364,6 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re
Msg("Failed to update creator membership in state store after creating room")
}
for _, evt := range req.InitialState {
- evt.RoomID = resp.RoomID
- if evt.StateKey == nil {
- evt.StateKey = ptr.Ptr("")
- }
UpdateStateStore(ctx, cli.StateStore, evt)
}
inviteMembership := event.MembershipInvite
@@ -1582,6 +1378,9 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re
Msg("Failed to update membership in state store after creating room")
}
}
+ for _, evt := range req.InitialState {
+ cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
+ }
}
return
}
@@ -1752,34 +1551,22 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy
"format": "event",
})
_, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &evt)
+ if err == nil && cli.StateStore != nil {
+ UpdateStateStore(ctx, cli.StateStore, evt)
+ }
if evt != nil {
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
- if evt.RoomID == "" {
- evt.RoomID = roomID
- }
- }
- if err == nil && cli.StateStore != nil {
- UpdateStateStore(ctx, cli.StateStore, evt)
}
return
}
// parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map.
-func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
- if res.ContentLength > limit {
- return nil, HTTPError{
- Request: req,
- Response: res,
-
- Message: "not reading response",
- WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
- }
- }
+func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
response := make(RoomStateMap)
responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event)
*responsePtr = response
- dec := json.NewDecoder(io.LimitReader(res.Body, limit))
+ dec := json.NewDecoder(res.Body)
arrayStart, err := dec.Token()
if err != nil {
@@ -1813,8 +1600,6 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any
return nil, nil
}
-type RoomStateMap = map[event.Type]map[string]*event.Event
-
// State gets all state in a room.
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate
func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) {
@@ -1824,21 +1609,12 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
ResponseJSON: &stateMap,
Handler: parseRoomStateArray,
})
- if stateMap != nil {
- pls, ok := stateMap[event.StatePowerLevels][""]
- if ok {
- pls.Content.AsPowerLevels().CreateEvent = stateMap[event.StateCreate][""]
- }
- }
if err == nil && cli.StateStore != nil {
for evtType, evts := range stateMap {
if evtType == event.StateMember {
continue
}
for _, evt := range evts {
- if evt.RoomID == "" {
- evt.RoomID = roomID
- }
UpdateStateStore(ctx, cli.StateStore, evt)
}
}
@@ -1897,9 +1673,6 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa
}
func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) {
- if mxcURL.IsEmpty() {
- return nil, fmt.Errorf("empty mxc uri provided to Download")
- }
_, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
Method: http.MethodGet,
URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID),
@@ -1908,41 +1681,6 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re
return resp, err
}
-type DownloadThumbnailExtra struct {
- Method string
- Animated bool
-}
-
-func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) {
- if mxcURL.IsEmpty() {
- return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail")
- }
- if len(extras) > 1 {
- panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras)))
- }
- var extra DownloadThumbnailExtra
- if len(extras) == 1 {
- extra = extras[0]
- }
- path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID}
- query := map[string]string{
- "height": strconv.Itoa(height),
- "width": strconv.Itoa(width),
- }
- if extra.Method != "" {
- query["method"] = extra.Method
- }
- if extra.Animated {
- query["animated"] = "true"
- }
- _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
- Method: http.MethodGet,
- URL: cli.BuildURLWithQuery(path, query),
- DontReadResponse: true,
- })
- return resp, err
-}
-
func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) {
resp, err := cli.Download(ctx, mxcURL)
if err != nil {
@@ -1989,15 +1727,10 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr
}
req.MXC = resp.ContentURI
req.UnstableUploadURL = resp.UnstableUploadURL
- if req.AsyncContext == nil {
- req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background())
- }
go func() {
- _, err = cli.UploadMedia(req.AsyncContext, req)
+ _, err = cli.UploadMedia(ctx, req)
if err != nil {
- zerolog.Ctx(req.AsyncContext).Err(err).
- Stringer("mxc", req.MXC).
- Msg("Async upload of media failed")
+ cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed")
}
}()
return resp, nil
@@ -2033,7 +1766,6 @@ type ReqUploadMedia struct {
ContentType string
FileName string
- AsyncContext context.Context
DoneCallback func()
// MXC specifies an existing MXC URI which doesn't have content yet to upload into.
@@ -2046,19 +1778,14 @@ type ReqUploadMedia struct {
}
func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) {
- cli.Log.Debug().
- Str("url", url).
- Int64("content_length", contentLength).
- Msg("Uploading media to external URL")
+ cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL")
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
req.Header.Set("Content-Type", contentType)
- if cli.UserAgent != "" {
- req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)")
- }
+ req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)")
if cli.ExternalClient != nil {
return cli.ExternalClient.Do(req)
@@ -2098,16 +1825,8 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*
Msg("Error uploading media to external URL, not retrying")
return nil, err
}
- backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries)
- cli.Log.Warn().Err(err).
- Str("url", data.UnstableUploadURL).
- Int("retry_in_seconds", int(backoff.Seconds())).
+ cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err).
Msg("Error uploading media to external URL, retrying")
- select {
- case <-time.After(backoff):
- case <-ctx.Done():
- return nil, ctx.Err()
- }
retries--
_, err = readerSeeker.Seek(0, io.SeekStart)
if err != nil {
@@ -2687,15 +2406,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req
return err
}
-func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error {
+func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error {
urlPath := cli.BuildClientURL("v3", "devices", deviceID)
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err
}
-func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error {
+func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error {
urlPath := cli.BuildClientURL("v3", "delete_devices")
- _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil)
+ _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err
}
@@ -2704,7 +2423,7 @@ type UIACallback = func(*RespUserInteractive) interface{}
// UploadCrossSigningKeys uploads the given cross-signing keys to the server.
// Because the endpoint requires user-interactive authentication a callback must be provided that,
// given the UI auth parameters, produces the required result (or nil to end the flow).
-func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error {
+func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error {
content, err := cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"),
@@ -2786,61 +2505,24 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri
return err
}
-// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin.
+// BatchSend sends a batch of historical events into a room. This is only available for appservices.
//
-// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
-func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) {
- urlPath := cli.BuildClientURL("v3", "admin", "whois", userID)
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
- return
-}
-
-func (cli *Client) makeMSC4323URL(action string, target id.UserID) string {
- if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) {
- return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target)
- } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) {
- return cli.BuildClientURL("v1", "admin", action, target)
+// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead.
+func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) {
+ path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"}
+ query := map[string]string{
+ "prev_event_id": req.PrevEventID.String(),
}
- return ""
-}
-
-// GetSuspendedStatus uses MSC4323 to check if a user is suspended.
-func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) {
- urlPath := cli.makeMSC4323URL("suspend", userID)
- if urlPath == "" {
- return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ if req.BeeperNewMessages {
+ query["com.beeper.new_messages"] = "true"
}
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
- return
-}
-
-// GetLockStatus uses MSC4323 to check if a user is locked.
-func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) {
- urlPath := cli.makeMSC4323URL("lock", userID)
- if urlPath == "" {
- return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ if req.BeeperMarkReadBy != "" {
+ query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String()
}
- _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
- return
-}
-
-// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended.
-func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) {
- urlPath := cli.makeMSC4323URL("suspend", userID)
- if urlPath == "" {
- return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ if len(req.BatchID) > 0 {
+ query["batch_id"] = req.BatchID.String()
}
- _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res)
- return
-}
-
-// SetLockStatus uses MSC4323 to set whether a user account is locked.
-func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) {
- urlPath := cli.makeMSC4323URL("lock", userID)
- if urlPath == "" {
- return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
- }
- _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res)
+ _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp)
return
}
diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go
deleted file mode 100644
index c2846427..00000000
--- a/client_ephemeral_test.go
+++ /dev/null
@@ -1,158 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package mautrix_test
-
-import (
- "context"
- "encoding/json"
- "errors"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) {
- roomID := id.RoomID("!room:example.com")
- evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
- txnID := "txn-123"
-
- var gotPath string
- var gotQueryTS string
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- gotPath = r.URL.Path
- gotQueryTS = r.URL.Query().Get("ts")
- assert.Equal(t, http.MethodPut, r.Method)
- _, _ = w.Write([]byte(`{"event_id":"$evt"}`))
- }))
- defer ts.Close()
-
- cli, err := mautrix.NewClient(ts.URL, "", "")
- require.NoError(t, err)
-
- _, err = cli.BeeperSendEphemeralEvent(
- context.Background(),
- roomID,
- evtType,
- map[string]any{"foo": "bar"},
- mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234},
- )
- require.NoError(t, err)
-
- assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/"))
- assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID))
- assert.Equal(t, "1234", gotQueryTS)
-}
-
-func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) {
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusNotFound)
- _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`))
- }))
- defer ts.Close()
-
- cli, err := mautrix.NewClient(ts.URL, "", "")
- require.NoError(t, err)
-
- _, err = cli.BeeperSendEphemeralEvent(
- context.Background(),
- id.RoomID("!room:example.com"),
- event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType},
- map[string]any{"foo": "bar"},
- )
- require.Error(t, err)
- assert.True(t, errors.Is(err, mautrix.MUnrecognized))
-}
-
-func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) {
- roomID := id.RoomID("!room:example.com")
- evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
- txnID := "txn-encrypted"
-
- stateStore := mautrix.NewMemoryStateStore()
- err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{
- Algorithm: id.AlgorithmMegolmV1,
- })
- require.NoError(t, err)
-
- fakeCrypto := &fakeCryptoHelper{
- encryptedContent: &event.EncryptedEventContent{
- Algorithm: id.AlgorithmMegolmV1,
- MegolmCiphertext: []byte("ciphertext"),
- },
- }
-
- var gotPath string
- var gotBody map[string]any
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- gotPath = r.URL.Path
- assert.Equal(t, http.MethodPut, r.Method)
- err := json.NewDecoder(r.Body).Decode(&gotBody)
- require.NoError(t, err)
- _, _ = w.Write([]byte(`{"event_id":"$evt"}`))
- }))
- defer ts.Close()
-
- cli, err := mautrix.NewClient(ts.URL, "", "")
- require.NoError(t, err)
- cli.StateStore = stateStore
- cli.Crypto = fakeCrypto
-
- _, err = cli.BeeperSendEphemeralEvent(
- context.Background(),
- roomID,
- evtType,
- map[string]any{"foo": "bar"},
- mautrix.ReqSendEvent{TransactionID: txnID},
- )
- require.NoError(t, err)
-
- assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID))
- assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"])
- assert.Equal(t, 1, fakeCrypto.encryptCalls)
- assert.Equal(t, roomID, fakeCrypto.lastRoomID)
- assert.Equal(t, evtType, fakeCrypto.lastEventType)
-}
-
-type fakeCryptoHelper struct {
- encryptCalls int
- lastRoomID id.RoomID
- lastEventType event.Type
- lastEncryptInput any
- encryptedContent *event.EncryptedEventContent
-}
-
-func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) {
- f.encryptCalls++
- f.lastRoomID = roomID
- f.lastEventType = eventType
- f.lastEncryptInput = content
- return f.encryptedContent, nil
-}
-
-func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) {
- return nil, nil
-}
-
-func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool {
- return false
-}
-
-func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) {
-}
-
-func (f *fakeCryptoHelper) Init(context.Context) error {
- return nil
-}
diff --git a/commands/container.go b/commands/container.go
index 9b909b75..bc685b7b 100644
--- a/commands/container.go
+++ b/commands/container.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2026 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,20 +8,14 @@ package commands
import (
"fmt"
- "slices"
"strings"
"sync"
-
- "go.mau.fi/util/exmaps"
-
- "maunium.net/go/mautrix/event/cmdschema"
)
type CommandContainer[MetaType any] struct {
commands map[string]*Handler[MetaType]
aliases map[string]string
lock sync.RWMutex
- parent *Handler[MetaType]
}
func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
@@ -31,29 +25,6 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
}
}
-func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent {
- data := make(exmaps.Set[*Handler[MetaType]])
- cont.collectHandlers(data)
- specs := make([]*cmdschema.EventContent, 0, data.Size())
- for handler := range data.Iter() {
- if handler.Parameters != nil {
- specs = append(specs, handler.Spec())
- }
- }
- return specs
-}
-
-func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) {
- cont.lock.RLock()
- defer cont.lock.RUnlock()
- for _, handler := range cont.commands {
- into.Add(handler)
- if handler.subcommandContainer != nil {
- handler.subcommandContainer.collectHandlers(into)
- }
- }
-}
-
// Register registers the given command handlers.
func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) {
if cont == nil {
@@ -61,10 +32,7 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType])
}
cont.lock.Lock()
defer cont.lock.Unlock()
- for i, handler := range handlers {
- if handler == nil {
- panic(fmt.Errorf("handler #%d is nil", i+1))
- }
+ for _, handler := range handlers {
cont.registerOne(handler)
}
}
@@ -77,10 +45,6 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType])
} else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists {
panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget))
}
- if !slices.Contains(handler.parents, cont.parent) {
- handler.parents = append(handler.parents, cont.parent)
- handler.nestedNameCache = nil
- }
cont.commands[handler.Name] = handler
for _, alias := range handler.Aliases {
if strings.ToLower(alias) != alias {
diff --git a/commands/event.go b/commands/event.go
index 76d6c9f0..77a3c0d2 100644
--- a/commands/event.go
+++ b/commands/event.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2026 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,7 +8,6 @@ package commands
import (
"context"
- "encoding/json"
"fmt"
"strings"
@@ -36,8 +35,6 @@ type Event[MetaType any] struct {
// RawArgs is the same as args, but without the splitting by whitespace.
RawArgs string
- StructuredArgs json.RawMessage
-
Ctx context.Context
Log *zerolog.Logger
Proc *Processor[MetaType]
@@ -64,7 +61,7 @@ var IDHTMLParser = &format.HTMLParser{
}
// ParseEvent parses a message into a command event struct.
-func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] {
+func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" {
return nil
@@ -73,34 +70,12 @@ func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Even
if content.Format == event.FormatHTML {
text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx))
}
- if content.MSC4391BotCommand != nil {
- if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 {
- return nil
- }
- wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand)
- wrapped.RawInput = text
- return wrapped
- }
if len(text) == 0 {
return nil
}
return RawTextToEvent[MetaType](ctx, evt, text)
}
-func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] {
- commandParts := strings.Split(content.Command, " ")
- return &Event[MetaType]{
- Event: evt,
- // Fake a command and args to let the subcommand finder in Process work.
- Command: commandParts[0],
- Args: commandParts[1:],
- Ctx: ctx,
- Log: zerolog.Ctx(ctx),
-
- StructuredArgs: content.Arguments,
- }
-}
-
func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] {
parts := strings.Fields(text)
if len(parts) == 0 {
@@ -213,25 +188,3 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) {
evt.RawArgs = arg + " " + evt.RawArgs
evt.Args = append([]string{arg}, evt.Args...)
}
-
-func (evt *Event[MetaType]) ParseArgs(into any) error {
- return json.Unmarshal(evt.StructuredArgs, into)
-}
-
-func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) {
- err = evt.ParseArgs(&into)
- return
-}
-
-func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) {
- return func(evt *Event[MetaType]) {
- parsed, err := ParseArgs[T, MetaType](evt)
- if err != nil {
- evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct")
- // TODO better error, usage info? deduplicate with Process
- evt.Reply("Failed to parse arguments: %v", err)
- return
- }
- fn(evt, parsed)
- }
-}
diff --git a/commands/handler.go b/commands/handler.go
index 56f27f06..b01d594f 100644
--- a/commands/handler.go
+++ b/commands/handler.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2026 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,9 +8,6 @@ package commands
import (
"strings"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/event/cmdschema"
)
type Handler[MetaType any] struct {
@@ -28,63 +25,12 @@ type Handler[MetaType any] struct {
// Event.ShiftArg will likely be useful for implementing such parameters.
PreFunc func(ce *Event[MetaType])
- // Description is a short description of the command.
- Description *event.ExtensibleTextContainer
- // Parameters is a description of structured command parameters.
- // If set, the StructuredArgs field of Event will be populated.
- Parameters []*cmdschema.Parameter
- TailParam string
-
- parents []*Handler[MetaType]
- nestedNameCache []string
subcommandContainer *CommandContainer[MetaType]
}
-func (h *Handler[MetaType]) NestedNames() []string {
- if h.nestedNameCache != nil {
- return h.nestedNameCache
- }
- nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents))
- for _, parent := range h.parents {
- if parent == nil {
- nestedNames = append(nestedNames, h.Name)
- nestedNames = append(nestedNames, h.Aliases...)
- } else {
- for _, parentName := range parent.NestedNames() {
- nestedNames = append(nestedNames, parentName+" "+h.Name)
- for _, alias := range h.Aliases {
- nestedNames = append(nestedNames, parentName+" "+alias)
- }
- }
- }
- }
- h.nestedNameCache = nestedNames
- return nestedNames
-}
-
-func (h *Handler[MetaType]) Spec() *cmdschema.EventContent {
- names := h.NestedNames()
- return &cmdschema.EventContent{
- Command: names[0],
- Aliases: names[1:],
- Parameters: h.Parameters,
- Description: h.Description,
- TailParam: h.TailParam,
- }
-}
-
-func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) {
- if h.Parameters == nil {
- h.Parameters = other.Parameters
- h.TailParam = other.TailParam
- }
- h.Func = other.Func
-}
-
func (h *Handler[MetaType]) initSubcommandContainer() {
if len(h.Subcommands) > 0 {
h.subcommandContainer = NewCommandContainer[MetaType]()
- h.subcommandContainer.parent = h
h.subcommandContainer.Register(h.Subcommands...)
} else {
h.subcommandContainer = nil
diff --git a/commands/processor.go b/commands/processor.go
index 80f6745d..9341329b 100644
--- a/commands/processor.go
+++ b/commands/processor.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2026 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
case event.EventReaction:
parsed = proc.ParseReaction(ctx, evt)
case event.EventMessage:
- parsed = proc.ParseEvent(ctx, evt)
+ parsed = ParseEvent[MetaType](ctx, evt)
}
- if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) {
+ if parsed == nil || !proc.PreValidator.Validate(parsed) {
return
}
parsed.Proc = proc
@@ -107,12 +107,6 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
break
}
}
- if parsed.StructuredArgs != nil && len(parsed.Args) > 0 {
- // TODO allow unknown command handlers to be called?
- // The client sent MSC4391 data, but the target command wasn't found
- log.Debug().Msg("Didn't find handler for MSC4391 command")
- return
- }
logWith := log.With().
Str("command", parsed.Command).
@@ -122,31 +116,11 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
}
if proc.LogArgs {
logWith = logWith.Strs("args", parsed.Args)
- if parsed.StructuredArgs != nil {
- logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs)
- }
}
log = logWith.Logger()
parsed.Ctx = log.WithContext(ctx)
parsed.Log = &log
- if handler.Parameters != nil && parsed.StructuredArgs == nil {
- // The handler wants structured parameters, but the client didn't send MSC4391 data
- var err error
- parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs)
- if err != nil {
- log.Debug().Err(err).Msg("Failed to parse structured arguments")
- // TODO better error, usage info? deduplicate with WithParsedArgs
- parsed.Reply("Failed to parse arguments: %v", err)
- return
- }
- if proc.LogArgs {
- log.UpdateContext(func(c zerolog.Context) zerolog.Context {
- return c.RawJSON("structured_args", parsed.StructuredArgs)
- })
- }
- }
-
log.Debug().Msg("Processing command")
handler.Func(parsed)
}
diff --git a/commands/reactions.go b/commands/reactions.go
index 0d316219..0df372e5 100644
--- a/commands/reactions.go
+++ b/commands/reactions.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2026 Tulir Asokan
+// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,7 +8,6 @@ package commands
import (
"context"
- "encoding/json"
"strings"
"github.com/rs/zerolog"
@@ -20,11 +19,6 @@ import (
const ReactionCommandsKey = "fi.mau.reaction_commands"
const ReactionMultiUseKey = "fi.mau.reaction_multi_use"
-type ReactionCommandData struct {
- Command string `json:"command"`
- Args any `json:"args,omitempty"`
-}
-
func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.ReactionEventContent)
if !ok {
@@ -73,33 +67,21 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E
Msg("Reaction command not found in target event")
return nil
}
- var wrappedEvt *Event[MetaType]
- switch typedCmd := rawCmd.(type) {
- case string:
- wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd)
- case map[string]any:
- var input event.MSC4391BotCommandInput
- if marshaled, err := json.Marshal(typedCmd); err != nil {
-
- } else if err = json.Unmarshal(marshaled, &input); err != nil {
-
- } else {
- wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input)
- }
- }
- if wrappedEvt == nil {
+ cmdString, ok := rawCmd.(string)
+ if !ok {
zerolog.Ctx(ctx).Debug().
Stringer("target_event_id", evtID).
Str("reaction_key", content.RelatesTo.Key).
Msg("Reaction command data is invalid")
return nil
}
+ wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString)
wrappedEvt.Proc = proc
wrappedEvt.Redact()
if !isMultiUse {
DeleteAllReactions(ctx, proc.Client, evt)
}
- if wrappedEvt.Command == "" {
+ if cmdString == "" {
return nil
}
return wrappedEvt
diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go
index d6611dc9..bb03f706 100644
--- a/crypto/aescbc/aes_cbc_test.go
+++ b/crypto/aescbc/aes_cbc_test.go
@@ -7,13 +7,11 @@
package aescbc_test
import (
+ "bytes"
"crypto/aes"
"crypto/rand"
"testing"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
"maunium.net/go/mautrix/crypto/aescbc"
)
@@ -24,23 +22,32 @@ func TestAESCBC(t *testing.T) {
// The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256)
key := make([]byte, 32)
_, err = rand.Read(key)
- require.NoError(t, err)
+ if err != nil {
+ t.Fatal(err)
+ }
iv := make([]byte, aes.BlockSize)
_, err = rand.Read(iv)
- require.NoError(t, err)
+ if err != nil {
+ t.Fatal(err)
+ }
plaintext = []byte("secret message for testing")
//increase to next block size
for len(plaintext)%8 != 0 {
plaintext = append(plaintext, []byte("-")...)
}
- ciphertext, err = aescbc.Encrypt(key, iv, plaintext)
- require.NoError(t, err)
+ if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil {
+ t.Fatal(err)
+ }
resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext)
- require.NoError(t, err)
+ if err != nil {
+ t.Fatal(err)
+ }
- assert.Equal(t, string(resultPlainText), string(plaintext))
+ if string(resultPlainText) != string(plaintext) {
+ t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext)
+ }
}
func TestAESCBCCase1(t *testing.T) {
@@ -54,10 +61,18 @@ func TestAESCBCCase1(t *testing.T) {
key := make([]byte, 32)
iv := make([]byte, aes.BlockSize)
encrypted, err := aescbc.Encrypt(key, iv, input)
- require.NoError(t, err)
- assert.Equal(t, expected, encrypted, "encrypted output does not match expected")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(expected, encrypted) {
+ t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected)
+ }
decrypted, err := aescbc.Decrypt(key, iv, encrypted)
- require.NoError(t, err)
- assert.Equal(t, input, decrypted, "decrypted output does not match input")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(input, decrypted) {
+ t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input)
+ }
}
diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go
index 727aacbf..cfa1c3e5 100644
--- a/crypto/attachment/attachments.go
+++ b/crypto/attachment/attachments.go
@@ -9,7 +9,6 @@ package attachment
import (
"crypto/aes"
"crypto/cipher"
- "crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
@@ -21,24 +20,13 @@ import (
)
var (
- ErrHashMismatch = errors.New("mismatching SHA-256 digest")
- ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version")
- ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
- ErrInvalidKey = errors.New("failed to decode key")
- ErrInvalidInitVector = errors.New("failed to decode initialization vector")
- ErrInvalidHash = errors.New("failed to decode SHA-256 hash")
- ErrReaderClosed = errors.New("encrypting reader was already closed")
-)
-
-// Deprecated: use variables prefixed with Err
-var (
- HashMismatch = ErrHashMismatch
- UnsupportedVersion = ErrUnsupportedVersion
- UnsupportedAlgorithm = ErrUnsupportedAlgorithm
- InvalidKey = ErrInvalidKey
- InvalidInitVector = ErrInvalidInitVector
- InvalidHash = ErrInvalidHash
- ReaderClosed = ErrReaderClosed
+ HashMismatch = errors.New("mismatching SHA-256 digest")
+ UnsupportedVersion = errors.New("unsupported Matrix file encryption version")
+ UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
+ InvalidKey = errors.New("failed to decode key")
+ InvalidInitVector = errors.New("failed to decode initialization vector")
+ InvalidHash = errors.New("failed to decode SHA-256 hash")
+ ReaderClosed = errors.New("encrypting reader was already closed")
)
var (
@@ -96,25 +84,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error {
if ef.decoded != nil {
return nil
} else if len(ef.Key.Key) != keyBase64Length {
- return ErrInvalidKey
+ return InvalidKey
} else if len(ef.InitVector) != ivBase64Length {
- return ErrInvalidInitVector
+ return InvalidInitVector
} else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length {
- return ErrInvalidHash
+ return InvalidHash
}
ef.decoded = &decodedKeys{}
_, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key))
if err != nil {
- return ErrInvalidKey
+ return InvalidKey
}
_, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector))
if err != nil {
- return ErrInvalidInitVector
+ return InvalidInitVector
}
if includeHash {
_, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256))
if err != nil {
- return ErrInvalidHash
+ return InvalidHash
}
}
return nil
@@ -190,7 +178,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil)
func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
if r.closed {
- return 0, ErrReaderClosed
+ return 0, ReaderClosed
}
if offset != 0 || whence != io.SeekStart {
return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported")
@@ -211,20 +199,15 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
func (r *encryptingReader) Read(dst []byte) (n int, err error) {
if r.closed {
- return 0, ErrReaderClosed
+ return 0, ReaderClosed
} else if r.isDecrypting && r.file.decoded == nil {
if err = r.file.PrepareForDecryption(); err != nil {
return
}
}
n, err = r.source.Read(dst)
- if r.isDecrypting {
- r.hash.Write(dst[:n])
- }
r.stream.XORKeyStream(dst[:n], dst[:n])
- if !r.isDecrypting {
- r.hash.Write(dst[:n])
- }
+ r.hash.Write(dst[:n])
return
}
@@ -234,8 +217,10 @@ func (r *encryptingReader) Close() (err error) {
err = closer.Close()
}
if r.isDecrypting {
- if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) {
- return ErrHashMismatch
+ var downloadedChecksum [utils.SHAHashLength]byte
+ r.hash.Sum(downloadedChecksum[:])
+ if downloadedChecksum != r.file.decoded.sha256 {
+ return HashMismatch
}
} else {
r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil))
@@ -276,9 +261,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) {
// DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function.
func (ef *EncryptedFile) PrepareForDecryption() error {
if ef.Version != "v2" {
- return ErrUnsupportedVersion
+ return UnsupportedVersion
} else if ef.Key.Algorithm != "A256CTR" {
- return ErrUnsupportedAlgorithm
+ return UnsupportedAlgorithm
} else if err := ef.decodeKeys(true); err != nil {
return err
}
@@ -289,13 +274,12 @@ func (ef *EncryptedFile) PrepareForDecryption() error {
func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
if err := ef.PrepareForDecryption(); err != nil {
return err
+ } else if ef.decoded.sha256 != sha256.Sum256(data) {
+ return HashMismatch
+ } else {
+ utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
+ return nil
}
- dataHash := sha256.Sum256(data)
- if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) {
- return ErrHashMismatch
- }
- utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
- return nil
}
// DecryptStream wraps the given io.Reader in order to decrypt the data.
@@ -308,10 +292,9 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadSeekCloser {
block, _ := aes.NewCipher(ef.decoded.key[:])
return &encryptingReader{
- isDecrypting: true,
- stream: cipher.NewCTR(block, ef.decoded.iv[:]),
- hash: sha256.New(),
- source: reader,
- file: ef,
+ stream: cipher.NewCTR(block, ef.decoded.iv[:]),
+ hash: sha256.New(),
+ source: reader,
+ file: ef,
}
}
diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go
index 9fe929ab..d7f1394a 100644
--- a/crypto/attachment/attachments_test.go
+++ b/crypto/attachment/attachments_test.go
@@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) {
file := parseHelloWorld()
file.Version = "foo"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, ErrUnsupportedVersion)
+ assert.ErrorIs(t, err, UnsupportedVersion)
}
func TestUnsupportedAlgorithm(t *testing.T) {
file := parseHelloWorld()
file.Key.Algorithm = "bar"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
+ assert.ErrorIs(t, err, UnsupportedAlgorithm)
}
func TestHashMismatch(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes))
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, ErrHashMismatch)
+ assert.ErrorIs(t, err, HashMismatch)
}
func TestTooLongHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, ErrInvalidHash)
+ assert.ErrorIs(t, err, InvalidHash)
}
func TestTooShortHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "5/Gy1JftyyQ"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, ErrInvalidHash)
+ assert.ErrorIs(t, err, InvalidHash)
}
diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go
index 25250178..ec551dbe 100644
--- a/crypto/backup/encryptedsessiondata.go
+++ b/crypto/backup/encryptedsessiondata.go
@@ -68,10 +68,6 @@ func calculateCompatMAC(macKey []byte) []byte {
//
// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) {
- return EncryptSessionDataWithPubkey(backupKey.PublicKey(), sessionData)
-}
-
-func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T) (*EncryptedSessionData[T], error) {
sessionJSON, err := json.Marshal(sessionData)
if err != nil {
return nil, err
@@ -82,7 +78,7 @@ func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T)
return nil, err
}
- sharedSecret, err := ephemeralKey.ECDH(pubkey)
+ sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey())
if err != nil {
return nil, err
}
diff --git a/crypto/canonicaljson/json_test.go b/crypto/canonicaljson/json_test.go
index 36476aa4..d1a7f0a5 100644
--- a/crypto/canonicaljson/json_test.go
+++ b/crypto/canonicaljson/json_test.go
@@ -17,43 +17,31 @@ package canonicaljson
import (
"testing"
-
- "github.com/stretchr/testify/assert"
)
-func TestSortJSON(t *testing.T) {
- var tests = []struct {
- input string
- want string
- }{
- {"{}", "{}"},
- {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`},
- {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`},
- {`[true,false,null]`, `[true,false,null]`},
- {`[9007199254740991]`, `[9007199254740991]`},
- {"\t\n[9007199254740991]", `[9007199254740991]`},
- {`[true,false,null]`, `[true,false,null]`},
- {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`},
- {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`},
- {`[true,false,null]`, `[true,false,null]`},
- {`[9007199254740991]`, `[9007199254740991]`},
- {"\t\n[9007199254740991]", `[9007199254740991]`},
- {`[true,false,null]`, `[true,false,null]`},
- }
- for _, test := range tests {
- t.Run(test.input, func(t *testing.T) {
- got := SortJSON([]byte(test.input), nil)
+func testSortJSON(t *testing.T, input, want string) {
+ got := SortJSON([]byte(input), nil)
- // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
- assert.EqualValues(t, test.want, string(CompactJSON(got, nil)))
- })
+ // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace.
+ if string(CompactJSON(got, nil)) != want {
+ t.Errorf("SortJSON(%q): want %q got %q", input, want, got)
}
}
+func TestSortJSON(t *testing.T) {
+ testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`)
+ testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`,
+ `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`)
+ testSortJSON(t, `[true,false,null]`, `[true,false,null]`)
+ testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`)
+ testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`)
+}
+
func testCompactJSON(t *testing.T, input, want string) {
- t.Helper()
got := string(CompactJSON([]byte(input), nil))
- assert.EqualValues(t, want, got)
+ if got != want {
+ t.Errorf("CompactJSON(%q): want %q got %q", input, want, got)
+ }
}
func TestCompactJSON(t *testing.T) {
@@ -86,23 +74,18 @@ func TestCompactJSON(t *testing.T) {
testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`)
}
-func TestReadHex(t *testing.T) {
- tests := []struct {
- input string
- want uint32
- }{
-
- {"0123", 0x0123},
- {"4567", 0x4567},
- {"89AB", 0x89AB},
- {"CDEF", 0xCDEF},
- {"89ab", 0x89AB},
- {"cdef", 0xCDEF},
- }
- for _, test := range tests {
- t.Run(test.input, func(t *testing.T) {
- got := readHexDigits([]byte(test.input))
- assert.Equal(t, test.want, got)
- })
+func testReadHex(t *testing.T, input string, want uint32) {
+ got := readHexDigits([]byte(input))
+ if want != got {
+ t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got)
}
}
+
+func TestReadHex(t *testing.T) {
+ testReadHex(t, "0123", 0x0123)
+ testReadHex(t, "4567", 0x4567)
+ testReadHex(t, "89AB", 0x89AB)
+ testReadHex(t, "CDEF", 0xCDEF)
+ testReadHex(t, "89ab", 0x89AB)
+ testReadHex(t, "cdef", 0xCDEF)
+}
diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go
index 5d9bf5b3..4094f695 100644
--- a/crypto/cross_sign_key.go
+++ b/crypto/cross_sign_key.go
@@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross
}
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
- err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{
+ err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
Master: masterKey,
SelfSigning: selfKey,
UserSigning: userKey,
diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go
index 223fc7b5..77efab5b 100644
--- a/crypto/cross_sign_pubkey.go
+++ b/crypto/cross_sign_pubkey.go
@@ -20,20 +20,6 @@ type CrossSigningPublicKeysCache struct {
UserSigningKey id.Ed25519
}
-func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) {
- pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx)
- if pubkeys != nil {
- hasKeys = true
- isVerified, err = mach.CryptoStore.IsKeySignedBy(
- ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey,
- )
- if err != nil {
- err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
- }
- }
- return
-}
-
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache {
if mach.crossSigningPubkeys != nil {
return mach.crossSigningPubkeys
@@ -63,8 +49,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id
if len(dbKeys) > 0 {
masterKey, ok := dbKeys[id.XSUsageMaster]
if ok {
- selfSigning := dbKeys[id.XSUsageSelfSigning]
- userSigning := dbKeys[id.XSUsageUserSigning]
+ selfSigning, _ := dbKeys[id.XSUsageSelfSigning]
+ userSigning, _ := dbKeys[id.XSUsageUserSigning]
return &CrossSigningPublicKeysCache{
MasterKey: masterKey.Key,
SelfSigningKey: selfSigning.Key,
diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go
index fd42880d..389a9fd2 100644
--- a/crypto/cross_sign_ssss.go
+++ b/crypto/cross_sign_ssss.go
@@ -8,7 +8,6 @@ package crypto
import (
"context"
- "errors"
"fmt"
"maunium.net/go/mautrix"
@@ -72,46 +71,6 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex
}, passphrase)
}
-func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error {
- keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx)
- if err != nil {
- return fmt.Errorf("failed to get default SSSS key data: %w", err)
- }
- key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
- if errors.Is(err, ssss.ErrUnverifiableKey) {
- mach.machOrContextLog(ctx).Warn().
- Str("key_id", keyID).
- Msg("SSSS key is unverifiable, trying to use without verifying")
- } else if err != nil {
- return err
- }
- err = mach.FetchCrossSigningKeysFromSSSS(ctx, key)
- if err != nil {
- return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
- }
- err = mach.SignOwnDevice(ctx, mach.OwnIdentity())
- if err != nil {
- return fmt.Errorf("failed to sign own device: %w", err)
- }
- err = mach.SignOwnMasterKey(ctx)
- if err != nil {
- return fmt.Errorf("failed to sign own master key: %w", err)
- }
- return nil
-}
-
-func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) {
- recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
- if err != nil {
- err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err)
- } else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil {
- err = fmt.Errorf("failed to sign own device: %w", err)
- } else if err = mach.SignOwnMasterKey(ctx); err != nil {
- err = fmt.Errorf("failed to sign own master key: %w", err)
- }
- return
-}
-
// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys.
//
// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key
@@ -138,12 +97,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u
// Publish cross-signing keys
err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback)
if err != nil {
- return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err)
+ return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err)
}
err = mach.SSSS.SetDefaultKeyID(ctx, key.ID)
if err != nil {
- return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
+ return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
}
return key.RecoveryKey(), keysCache, nil
diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go
index 57406b11..b583bada 100644
--- a/crypto/cross_sign_store.go
+++ b/crypto/cross_sign_store.go
@@ -20,34 +20,36 @@ import (
func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
log := mach.machOrContextLog(ctx)
for userID, userKeys := range crossSigningKeys {
- log := log.With().Stringer("user_id", userID).Logger()
+ log := log.With().Str("user_id", userID.String()).Logger()
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
}
- for curKeyUsage, curKey := range currentKeys {
- log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
- // got a new key with the same usage as an existing key
- for _, newKeyUsage := range userKeys.Usage {
- if newKeyUsage == curKeyUsage {
- if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
- // old key is not in the new key map, so we drop signatures made by it
- if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
- log.Error().Err(err).Msg("Error deleting old signatures made by user")
- } else {
- log.Debug().
- Int64("signature_count", count).
- Msg("Dropped signatures made by old key as it has been replaced")
+ if currentKeys != nil {
+ for curKeyUsage, curKey := range currentKeys {
+ log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger()
+ // got a new key with the same usage as an existing key
+ for _, newKeyUsage := range userKeys.Usage {
+ if newKeyUsage == curKeyUsage {
+ if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
+ // old key is not in the new key map, so we drop signatures made by it
+ if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
+ log.Error().Err(err).Msg("Error deleting old signatures made by user")
+ } else {
+ log.Debug().
+ Int64("signature_count", count).
+ Msg("Dropped signatures made by old key as it has been replaced")
+ }
}
+ break
}
- break
}
}
}
for _, key := range userKeys.Keys {
- log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
+ log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
for _, usage := range userKeys.Usage {
log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key")
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {
diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go
index b70370a2..5e1ffd50 100644
--- a/crypto/cross_sign_test.go
+++ b/crypto/cross_sign_test.go
@@ -13,8 +13,6 @@ import (
"testing"
"github.com/rs/zerolog"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
@@ -26,12 +24,17 @@ var noopLogger = zerolog.Nop()
func getOlmMachine(t *testing.T) *OlmMachine {
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
- require.NoError(t, err, "Error opening raw database")
+ if err != nil {
+ t.Fatalf("Error opening db: %v", err)
+ }
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
- require.NoError(t, err, "Error creating database wrapper")
+ if err != nil {
+ t.Fatalf("Error opening db: %v", err)
+ }
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
- err = sqlStore.DB.Upgrade(context.TODO())
- require.NoError(t, err, "Error upgrading database")
+ if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
+ t.Fatalf("Error creating tables: %v", err)
+ }
userID := id.UserID("@mautrix")
mk, _ := olm.NewPKSigning()
@@ -63,25 +66,29 @@ func TestTrustOwnDevice(t *testing.T) {
DeviceID: "device",
SigningKey: id.Ed25519("deviceKey"),
}
- assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be")
+ if m.IsDeviceTrusted(context.TODO(), ownDevice) {
+ t.Error("Own device trusted while it shouldn't be")
+ }
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(),
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey,
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2")
- trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID)
- require.NoError(t, err, "Error checking if own user is trusted")
- assert.True(t, trusted, "Own user not trusted while they should be")
- assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be")
+ if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted {
+ t.Error("Own user not trusted while they should be")
+ }
+ if !m.IsDeviceTrusted(context.TODO(), ownDevice) {
+ t.Error("Own device not trusted while it should be")
+ }
}
func TestTrustOtherUser(t *testing.T) {
m := getOlmMachine(t)
otherUser := id.UserID("@user")
- trusted, err := m.IsUserTrusted(context.TODO(), otherUser)
- require.NoError(t, err, "Error checking if other user is trusted")
- assert.False(t, trusted, "Other user trusted while they shouldn't be")
+ if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
+ t.Error("Other user trusted while they shouldn't be")
+ }
theirMasterKey, _ := olm.NewPKSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
@@ -93,16 +100,16 @@ func TestTrustOtherUser(t *testing.T) {
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig")
- trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
- require.NoError(t, err, "Error checking if other user is trusted")
- assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key")
+ if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
+ t.Error("Other user trusted before their master key has been signed with our user-signing key")
+ }
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
- trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
- require.NoError(t, err, "Error checking if other user is trusted")
- assert.True(t, trusted, "Other user not trusted while they should be")
+ if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
+ t.Error("Other user not trusted while they should be")
+ }
}
func TestTrustOtherDevice(t *testing.T) {
@@ -113,11 +120,12 @@ func TestTrustOtherDevice(t *testing.T) {
DeviceID: "theirDevice",
SigningKey: id.Ed25519("theirDeviceKey"),
}
-
- trusted, err := m.IsUserTrusted(context.TODO(), otherUser)
- require.NoError(t, err, "Error checking if other user is trusted")
- assert.False(t, trusted, "Other user trusted while they shouldn't be")
- assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be")
+ if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
+ t.Error("Other user trusted while they shouldn't be")
+ }
+ if m.IsDeviceTrusted(context.TODO(), theirDevice) {
+ t.Error("Other device trusted while it shouldn't be")
+ }
theirMasterKey, _ := olm.NewPKSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
@@ -129,17 +137,21 @@ func TestTrustOtherDevice(t *testing.T) {
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
- trusted, err = m.IsUserTrusted(context.TODO(), otherUser)
- require.NoError(t, err, "Error checking if other user is trusted")
- assert.True(t, trusted, "Other user not trusted while they should be")
+ if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
+ t.Error("Other user not trusted while they should be")
+ }
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(),
otherUser, theirMasterKey.PublicKey(), "sig3")
- assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK")
+ if m.IsDeviceTrusted(context.TODO(), theirDevice) {
+ t.Error("Other device trusted before it has been signed with user's SSK")
+ }
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey,
otherUser, theirSSK.PublicKey(), "sig4")
- assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK")
+ if !m.IsDeviceTrusted(context.TODO(), theirDevice) {
+ t.Error("Other device not trusted while it should be")
+ }
}
diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go
index b62dc128..56f8b484 100644
--- a/crypto/cryptohelper/cryptohelper.go
+++ b/crypto/cryptohelper/cryptohelper.go
@@ -225,6 +225,13 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted)
}
+ if helper.client.SetAppServiceDeviceID {
+ err = helper.mach.ShareKeys(ctx, -1)
+ if err != nil {
+ return fmt.Errorf("failed to share keys: %w", err)
+ }
+ }
+
return nil
}
@@ -261,24 +268,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error
if !ok || len(device.Keys) == 0 {
if isShared {
return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server")
+ } else {
+ helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine")
+ return nil
}
- helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys")
- err = helper.mach.ShareKeys(ctx, -1)
- if err != nil {
- return fmt.Errorf("failed to share keys: %w", err)
- }
- return nil
} else if !isShared {
return fmt.Errorf("olm account is not marked as shared, but there are keys on the server")
} else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed {
return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed)
+ }
+ if !isShared {
+ helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?")
} else {
helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine")
- return nil
}
+ return nil
}
-var NoSessionFound = crypto.ErrNoSessionFound
+var NoSessionFound = crypto.NoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
@@ -297,14 +304,24 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even
ctx = log.WithContext(ctx)
decrypted, err := helper.Decrypt(ctx, evt)
- if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" {
- go helper.waitForSession(ctx, evt)
- } else if err != nil {
+ if errors.Is(err, NoSessionFound) {
+ log.Debug().
+ Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
+ Msg("Couldn't find session, waiting for keys to arrive...")
+ if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
+ log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
+ decrypted, err = helper.Decrypt(ctx, evt)
+ } else {
+ go helper.waitLongerForSession(ctx, log, evt)
+ return
+ }
+ }
+ if err != nil {
log.Warn().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
- } else {
- helper.postDecrypt(ctx, decrypted)
+ return
}
+ helper.postDecrypt(ctx, decrypted)
}
func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) {
@@ -345,33 +362,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
}
}
-func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) {
- log := zerolog.Ctx(ctx)
- content := evt.Content.AsEncrypted()
-
- log.Debug().
- Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
- Msg("Couldn't find session, waiting for keys to arrive...")
- if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
- log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
- decrypted, err := helper.Decrypt(ctx, evt)
- if err != nil {
- log.Warn().Err(err).Msg("Failed to decrypt event")
- helper.DecryptErrorCallback(evt, err)
- } else {
- helper.postDecrypt(ctx, decrypted)
- }
- } else {
- go helper.waitLongerForSession(ctx, evt)
- }
-}
-
-func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) {
- log := zerolog.Ctx(ctx)
+func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) {
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
- //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@@ -419,7 +413,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R
defer helper.lock.RUnlock()
encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content)
if err != nil {
- if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) {
+ if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
return
}
helper.log.Debug().
diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go
index 457d5a0c..47279474 100644
--- a/crypto/decryptmegolm.go
+++ b/crypto/decryptmegolm.go
@@ -24,23 +24,13 @@ import (
)
var (
- ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
- ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
- ErrDuplicateMessageIndex = errors.New("duplicate megolm message index")
- ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room")
- ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
- ErrRatchetError = errors.New("failed to ratchet session after use")
- ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload")
-)
-
-// Deprecated: use variables prefixed with Err
-var (
- IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType
- NoSessionFound = ErrNoSessionFound
- DuplicateMessageIndex = ErrDuplicateMessageIndex
- WrongRoom = ErrWrongRoom
- DeviceKeyMismatch = ErrDeviceKeyMismatch
- RatchetError = ErrRatchetError
+ IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
+ NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
+ DuplicateMessageIndex = errors.New("duplicate megolm message index")
+ WrongRoom = errors.New("encrypted megolm event is not intended for this room")
+ DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
+ SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
+ RatchetError = errors.New("failed to ratchet session after use")
)
type megolmEvent struct {
@@ -55,30 +45,13 @@ var (
relatesToTopLevelPath = exgjson.Path("content", "m.relates_to")
)
-const sessionIDLength = 43
-
-func validateCiphertextCharacters(ciphertext []byte) bool {
- for _, b := range ciphertext {
- if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' {
- return false
- }
- }
- return true
-}
-
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
- return nil, ErrIncorrectEncryptedContentType
+ return nil, IncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmMegolmV1 {
- return nil, ErrUnsupportedAlgorithm
- } else if len(content.MegolmCiphertext) < 74 {
- return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext))
- } else if len(content.SessionID) != sessionIDLength {
- return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID))
- } else if !validateCiphertextCharacters(content.MegolmCiphertext) {
- return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload)
+ return nil, UnsupportedAlgorithm
}
log := mach.machOrContextLog(ctx).With().
Str("action", "decrypt megolm event").
@@ -124,13 +97,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
Msg("Couldn't resolve trust level of session: sent by unknown device")
trustLevel = id.TrustStateUnknownDevice
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
- log.Debug().
- Stringer("session_sender_key", sess.SenderKey).
- Stringer("device_sender_key", device.IdentityKey).
- Stringer("session_signing_key", sess.SigningKey).
- Stringer("device_signing_key", device.SigningKey).
- Msg("Device keys don't match keys in session, marking as untrusted")
- trustLevel = id.TrustStateDeviceKeyMismatch
+ return nil, DeviceKeyMismatch
} else {
trustLevel, err = mach.ResolveTrustContext(ctx, device)
if err != nil {
@@ -180,7 +147,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
if err != nil {
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
} else if megolmEvt.RoomID != encryptionRoomID {
- return nil, ErrWrongRoom
+ return nil, WrongRoom
}
if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState {
megolmEvt.Type.Class = event.StateEventType
@@ -213,7 +180,6 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
TrustSource: device,
ForwardedKeys: forwardedKeys,
WasEncrypted: true,
- EventSource: evt.Mautrix.EventSource | event.SourceDecrypted,
ReceivedAt: evt.Mautrix.ReceivedAt,
},
}, nil
@@ -235,19 +201,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
if decodeErr != nil {
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
- return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex)
+ return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
}
firstKnown := sess.Internal.FirstKnownIndex()
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
- return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
} else if !ok {
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
- return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
}
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
- return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
}
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
@@ -258,11 +224,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
- return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID)
+ return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
+ } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
+ return sess, nil, 0, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
- if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
+ if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
}
@@ -270,7 +238,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
- return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex)
+ return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
}
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
@@ -322,24 +290,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
- return sess, plaintext, messageIndex, ErrRatchetError
+ return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Deleted fully used session")
}
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
- return sess, plaintext, messageIndex, ErrRatchetError
+ return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
- return sess, plaintext, messageIndex, ErrRatchetError
+ return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
- return sess, plaintext, messageIndex, ErrRatchetError
+ return sess, plaintext, messageIndex, RatchetError
} else {
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
}
diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go
index aea5e6dc..b737e4e1 100644
--- a/crypto/decryptolm.go
+++ b/crypto/decryptolm.go
@@ -17,36 +17,21 @@ import (
"time"
"github.com/rs/zerolog"
- "go.mau.fi/util/exerrors"
- "go.mau.fi/util/ptr"
- "maunium.net/go/mautrix/crypto/goolm/account"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
var (
- ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
- ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
- ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type")
- ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
- ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
- ErrSenderMismatch = errors.New("mismatched sender in olm payload")
- ErrRecipientMismatch = errors.New("mismatched recipient in olm payload")
- ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
- ErrDuplicateMessage = errors.New("duplicate olm message")
-)
-
-// Deprecated: use variables prefixed with Err
-var (
- UnsupportedAlgorithm = ErrUnsupportedAlgorithm
- NotEncryptedForMe = ErrNotEncryptedForMe
- UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType
- DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession
- DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage
- SenderMismatch = ErrSenderMismatch
- RecipientMismatch = ErrRecipientMismatch
- RecipientKeyMismatch = ErrRecipientKeyMismatch
+ UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
+ NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
+ UnsupportedOlmMessageType = errors.New("unsupported olm message type")
+ DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
+ DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
+ SenderMismatch = errors.New("mismatched sender in olm payload")
+ RecipientMismatch = errors.New("mismatched recipient in olm payload")
+ RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
+ ErrDuplicateMessage = errors.New("duplicate olm message")
)
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
@@ -68,13 +53,13 @@ type DecryptedOlmEvent struct {
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
- return nil, ErrIncorrectEncryptedContentType
+ return nil, IncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmOlmV1 {
- return nil, ErrUnsupportedAlgorithm
+ return nil, UnsupportedAlgorithm
}
ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()]
if !ok {
- return nil, ErrNotEncryptedForMe
+ return nil, NotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
@@ -90,7 +75,7 @@ type OlmEventKeys struct {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
- return nil, ErrUnsupportedOlmMessageType
+ return nil, UnsupportedOlmMessageType
}
log := mach.machOrContextLog(ctx).With().
@@ -114,11 +99,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
olmEvt.Type.Class = evt.Type.Class
if evt.Sender != olmEvt.Sender {
- return nil, ErrSenderMismatch
+ return nil, SenderMismatch
} else if mach.Client.UserID != olmEvt.Recipient {
- return nil, ErrRecipientMismatch
+ return nil, RecipientMismatch
} else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 {
- return nil, ErrRecipientKeyMismatch
+ return nil, RecipientKeyMismatch
}
if len(olmEvt.Content.VeryRaw) > 0 {
@@ -134,9 +119,6 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
func olmMessageHash(ciphertext string) ([32]byte, error) {
- if ciphertext == "" {
- return [32]byte{}, fmt.Errorf("empty ciphertext")
- }
ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext)
return sha256.Sum256(ciphertextBytes), err
}
@@ -166,7 +148,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
if err != nil {
- if err == ErrDecryptionFailedWithMatchingSession {
+ if err == DecryptionFailedWithMatchingSession {
log.Warn().Msg("Found matching session, but decryption failed")
go mach.unwedgeDevice(log, sender, senderKey)
}
@@ -184,10 +166,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(log, sender, senderKey)
- return nil, ErrDecryptionFailedForNormalMessage
+ return nil, DecryptionFailedForNormalMessage
}
- accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp"))
log.Trace().Msg("Trying to create inbound session")
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
@@ -199,7 +180,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
log.Debug().
Hex("ciphertext_hash", ciphertextHash[:]).
- Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
Str("olm_session_description", session.Describe()).
Msg("Created inbound olm session")
ctx = log.WithContext(ctx)
@@ -208,19 +188,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err = session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
- log.Debug().
- Hex("ciphertext_hash", ciphertextHash[:]).
- Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
- Str("ciphertext", ciphertext).
- Str("olm_session_description", session.Describe()).
- Msg("DEBUG: Failed to decrypt prekey olm message with newly created session")
- err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup)
- if err2 != nil {
- log.Debug().Err(err2).Msg("Goolm confirmed decryption failure")
- } else {
- log.Warn().Msg("Goolm decryption was successful after libolm failure?")
- }
-
go mach.unwedgeDevice(log, sender, senderKey)
return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err)
}
@@ -238,23 +205,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
return plaintext, nil
}
-func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error {
- acc, err := account.AccountFromPickled(accountBackup, []byte("tmp"))
- if err != nil {
- return fmt.Errorf("failed to unpickle olm account: %w", err)
- }
- sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext)
- if err != nil {
- return fmt.Errorf("failed to create inbound session: %w", err)
- }
- _, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey)
- if err != nil {
- // This is the expected result if libolm failed
- return fmt.Errorf("failed to decrypt with new session: %w", err)
- }
- return nil
-}
-
const MaxOlmSessionsPerDevice = 5
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
@@ -313,11 +263,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
if err != nil {
log.Warn().Err(err).
Hex("ciphertext_hash", ciphertextHash[:]).
- Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
Str("session_description", session.Describe()).
Msg("Failed to decrypt olm message")
if olmType == id.OlmMsgTypePreKey {
- return nil, ErrDecryptionFailedWithMatchingSession
+ return nil, DecryptionFailedWithMatchingSession
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
@@ -357,10 +306,10 @@ const MinUnwedgeInterval = 1 * time.Hour
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
log = log.With().Str("action", "unwedge olm session").Logger()
- ctx := log.WithContext(mach.backgroundCtx)
+ ctx := log.WithContext(mach.BackgroundCtx)
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
- delta := time.Since(prevUnwedge)
+ delta := time.Now().Sub(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
log.Debug().
Str("previous_recreation", delta.String()).
@@ -391,10 +340,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send
return
}
- log.Debug().
- Time("last_created", lastCreatedAt).
- Stringer("device_id", deviceIdentity.DeviceID).
- Msg("Creating new Olm session")
+ log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session")
mach.devicesToUnwedgeLock.Lock()
mach.devicesToUnwedge[senderKey] = true
mach.devicesToUnwedgeLock.Unlock()
diff --git a/crypto/devicelist.go b/crypto/devicelist.go
index f0d2b129..a2116ed5 100644
--- a/crypto/devicelist.go
+++ b/crypto/devicelist.go
@@ -22,23 +22,14 @@ import (
)
var (
- ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
- ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
- ErrMismatchingSigningKey = errors.New("received update for device with different signing key")
- ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key")
- ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
- ErrInvalidKeySignature = errors.New("invalid signature on device keys")
- ErrUserNotTracked = errors.New("user is not tracked")
-)
+ MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
+ MismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
+ MismatchingSigningKey = errors.New("received update for device with different signing key")
+ NoSigningKeyFound = errors.New("didn't find ed25519 signing key")
+ NoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
+ InvalidKeySignature = errors.New("invalid signature on device keys")
-// Deprecated: use variables prefixed with Err
-var (
- MismatchingDeviceID = ErrMismatchingDeviceID
- MismatchingUserID = ErrMismatchingUserID
- MismatchingSigningKey = ErrMismatchingSigningKey
- NoSigningKeyFound = ErrNoSigningKeyFound
- NoIdentityKeyFound = ErrNoIdentityKeyFound
- InvalidKeySignature = ErrInvalidKeySignature
+ ErrUserNotTracked = errors.New("user is not tracked")
)
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
@@ -215,7 +206,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received")
data = make(map[id.UserID]map[id.DeviceID]*id.Device)
for userID, devices := range resp.DeviceKeys {
- log := log.With().Stringer("user_id", userID).Logger()
+ log := log.With().Str("user_id", userID.String()).Logger()
delete(req.DeviceKeys, userID)
newDevices := make(map[id.DeviceID]*id.Device)
@@ -231,7 +222,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
Msg("Updating devices in store")
changed := false
for deviceID, deviceKeys := range devices {
- log := log.With().Stringer("device_id", deviceID).Logger()
+ log := log.With().Str("device_id", deviceID.String()).Logger()
existing, ok := existingDevices[deviceID]
if !ok {
// New device
@@ -279,7 +270,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
}
}
for userID := range req.DeviceKeys {
- log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user")
+ log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user")
}
mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys)
@@ -321,28 +312,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID)
func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) {
if deviceID != deviceKeys.DeviceID {
- return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID)
+ return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID)
} else if userID != deviceKeys.UserID {
- return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID)
+ return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID)
}
signingKey := deviceKeys.Keys.GetEd25519(deviceID)
identityKey := deviceKeys.Keys.GetCurve25519(deviceID)
if signingKey == "" {
- return nil, ErrNoSigningKeyFound
+ return nil, NoSigningKeyFound
} else if identityKey == "" {
- return nil, ErrNoIdentityKeyFound
+ return nil, NoIdentityKeyFound
}
if existing != nil && existing.SigningKey != signingKey {
- return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey)
+ return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
}
ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
if err != nil {
return existing, fmt.Errorf("failed to verify signature: %w", err)
} else if !ok {
- return existing, ErrInvalidKeySignature
+ return existing, InvalidKeySignature
}
name, ok := deviceKeys.Unsigned["device_display_name"].(string)
diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go
index 88f9c8d4..804e15de 100644
--- a/crypto/encryptmegolm.go
+++ b/crypto/encryptmegolm.go
@@ -25,12 +25,8 @@ import (
)
var (
- ErrNoGroupSession = errors.New("no group session created")
-)
-
-// Deprecated: use variables prefixed with Err
-var (
- NoGroupSession = ErrNoGroupSession
+ AlreadyShared = errors.New("group session already shared")
+ NoGroupSession = errors.New("no group session created")
)
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
@@ -46,7 +42,7 @@ func getRawJSON[T any](content json.RawMessage, path ...string) *T {
return &result
}
-func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo {
+func getRelatesTo(content any) *event.RelatesTo {
contentJSON, ok := content.(json.RawMessage)
if ok {
return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to")
@@ -59,7 +55,7 @@ func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo {
if ok {
return relatable.OptionalGetRelatesTo()
}
- return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to")
+ return nil
}
func getMentions(content any) *event.Mentions {
@@ -87,20 +83,15 @@ type rawMegolmEvent struct {
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
func IsShareError(err error) bool {
- return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession
+ return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
}
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
- if len(ciphertext) == 0 {
- return 0, fmt.Errorf("empty ciphertext")
- }
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
var err error
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
if err != nil {
return 0, err
- } else if len(decoded) < 2+binary.MaxVarintLen64 {
- return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded))
} else if decoded[0] != 3 || decoded[1] != 8 {
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
}
@@ -130,7 +121,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
} else if session == nil {
- return nil, ErrNoGroupSession
+ return nil, NoGroupSession
}
plaintext, err := json.Marshal(&rawMegolmEvent{
RoomID: roomID,
@@ -168,21 +159,12 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
Algorithm: id.AlgorithmMegolmV1,
SessionID: session.ID(),
MegolmCiphertext: ciphertext,
- RelatesTo: getRelatesTo(content, plaintext),
+ RelatesTo: getRelatesTo(content),
// These are deprecated
SenderKey: mach.account.IdentityKey(),
DeviceID: mach.Client.DeviceID,
}
- if mach.MSC4392Relations && encrypted.RelatesTo != nil {
- // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content.
- // Other relations like threads are still left unencrypted.
- encrypted.RelatesTo.InReplyTo = nil
- encrypted.RelatesTo.IsFallingBack = false
- if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" {
- encrypted.RelatesTo = nil
- }
- }
if mach.PlaintextMentions {
encrypted.Mentions = getMentions(content)
}
@@ -227,8 +209,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
- mach.machOrContextLog(ctx).Debug().Stringer("room_id", roomID).Msg("Not re-sharing group session, already shared")
- return nil
+ return AlreadyShared
}
log := mach.machOrContextLog(ctx).With().
Str("room_id", roomID.String()).
@@ -252,7 +233,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
var fetchKeysForUsers []id.UserID
for _, userID := range users {
- log := log.With().Stringer("target_user_id", userID).Logger()
+ log := log.With().Str("target_user_id", userID.String()).Logger()
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
if err != nil {
log.Err(err).Msg("Failed to get devices of user")
@@ -324,7 +305,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
toDeviceWithheld.Messages[userID] = withheld
}
- log := log.With().Stringer("target_user_id", userID).Logger()
+ log := log.With().Str("target_user_id", userID.String()).Logger()
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
log.Debug().
@@ -370,19 +351,26 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
log.Trace().Msg("Encrypting group session for all found devices")
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
- logUsers := zerolog.Dict()
for userID, sessions := range olmSessions {
if len(sessions) == 0 {
continue
}
- logDevices := zerolog.Dict()
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
+ log.Trace().
+ Stringer("target_user_id", userID).
+ Stringer("target_device_id", deviceID).
+ Stringer("target_identity_key", device.identity.IdentityKey).
+ Msg("Encrypting group session for device")
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
- logDevices.Str(string(deviceID), string(device.identity.IdentityKey))
deviceCount++
+ log.Debug().
+ Stringer("target_user_id", userID).
+ Stringer("target_device_id", deviceID).
+ Stringer("target_identity_key", device.identity.IdentityKey).
+ Msg("Encrypted group session for device")
if !mach.DisableSharedGroupSessionTracking {
err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id)
if err != nil {
@@ -396,13 +384,11 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
}
}
}
- logUsers.Dict(string(userID), logDevices)
}
log.Debug().
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
- Dict("destination_map", logUsers).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
return err
diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go
index 765307af..80b76dc5 100644
--- a/crypto/encryptolm.go
+++ b/crypto/encryptolm.go
@@ -96,19 +96,15 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
panic(err)
}
log := mach.machOrContextLog(ctx)
+ log.Debug().
+ Str("recipient_identity_key", recipient.IdentityKey.String()).
+ Str("olm_session_id", session.ID().String()).
+ Str("olm_session_description", session.Describe()).
+ Msg("Encrypting olm message")
msgType, ciphertext, err := session.Encrypt(plaintext)
if err != nil {
panic(err)
}
- ciphertextStr := string(ciphertext)
- ciphertextHash, _ := olmMessageHash(ciphertextStr)
- log.Debug().
- Stringer("event_type", evtType).
- Str("recipient_identity_key", recipient.IdentityKey.String()).
- Str("olm_session_id", session.ID().String()).
- Str("olm_session_description", session.Describe()).
- Hex("ciphertext_hash", ciphertextHash[:]).
- Msg("Encrypted olm message")
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
if err != nil {
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
@@ -119,7 +115,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
OlmCiphertext: event.OlmCiphertexts{
recipient.IdentityKey: {
Type: msgType,
- Body: ciphertextStr,
+ Body: string(ciphertext),
},
},
}
diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go
index b48843a4..4da08a73 100644
--- a/crypto/goolm/account/account.go
+++ b/crypto/goolm/account/account.go
@@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
if err != nil {
return err
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
- return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
+ return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
return err
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair
diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go
index d0dec5f0..e1c9b452 100644
--- a/crypto/goolm/account/account_test.go
+++ b/crypto/goolm/account/account_test.go
@@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) {
account, err := account.NewAccount()
assert.NoError(t, err)
err = account.Unpickle(pickled, pickleKey)
- assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion)
+ assert.ErrorIs(t, err, olm.ErrBadVersion)
}
func TestLoopback(t *testing.T) {
diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go
index ec392d7e..c6b9e523 100644
--- a/crypto/goolm/account/register.go
+++ b/crypto/goolm/account/register.go
@@ -10,7 +10,7 @@ import (
"maunium.net/go/mautrix/crypto/olm"
)
-func Register() {
+func init() {
olm.InitNewAccount = func() (olm.Account, error) {
return NewAccount()
}
diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go
index 6e42d886..e9759501 100644
--- a/crypto/goolm/crypto/curve25519.go
+++ b/crypto/goolm/crypto/curve25519.go
@@ -53,7 +53,6 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
// SharedSecret returns the shared secret between the key pair and the given public key.
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
- // Note: the standard library checks that the output is non-zero
return c.PrivateKey.SharedSecret(pubKey)
}
diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go
index 2550f15e..9039c126 100644
--- a/crypto/goolm/crypto/curve25519_test.go
+++ b/crypto/goolm/crypto/curve25519_test.go
@@ -25,8 +25,6 @@ func TestCurve25519(t *testing.T) {
fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey)
assert.NoError(t, err)
assert.Equal(t, fromPrivate, firstKeypair)
- _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength))
- assert.Error(t, err)
}
func TestCurve25519Case1(t *testing.T) {
diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go
index 58ee26f7..061a052a 100644
--- a/crypto/goolm/goolmbase64/base64.go
+++ b/crypto/goolm/goolmbase64/base64.go
@@ -4,8 +4,7 @@ import (
"encoding/base64"
)
-// These methods should only be used for raw byte operations, never with string conversion
-
+// Deprecated: base64.RawStdEncoding should be used directly
func Decode(input []byte) ([]byte, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
@@ -15,6 +14,7 @@ func Decode(input []byte) ([]byte, error) {
return decoded[:writtenBytes], nil
}
+// Deprecated: base64.RawStdEncoding should be used directly
func Encode(input []byte) []byte {
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
base64.RawStdEncoding.Encode(encoded, input)
diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/libolmpickle/picklejson.go
index f765391f..308e472c 100644
--- a/crypto/goolm/libolmpickle/picklejson.go
+++ b/crypto/goolm/libolmpickle/picklejson.go
@@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
}
}
if decrypted[0] != pickleVersion {
- return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion)
+ return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion)
}
err = json.Unmarshal(decrypted[1:], object)
if err != nil {
diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go
index b06756a9..a71cf302 100644
--- a/crypto/goolm/message/decoder.go
+++ b/crypto/goolm/message/decoder.go
@@ -3,9 +3,6 @@ package message
import (
"bytes"
"encoding/binary"
- "fmt"
-
- "maunium.net/go/mautrix/crypto/olm"
)
type Decoder struct {
@@ -23,8 +20,6 @@ func (d *Decoder) ReadVarInt() (uint64, error) {
func (d *Decoder) ReadVarBytes() ([]byte, error) {
if n, err := d.ReadVarInt(); err != nil {
return nil, err
- } else if n > uint64(d.Len()) {
- return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available())
} else {
out := make([]byte, n)
_, err = d.Read(out)
diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go
index c83540c1..c2a43b1f 100644
--- a/crypto/goolm/message/group_message.go
+++ b/crypto/goolm/message/group_message.go
@@ -2,12 +2,10 @@ package message
import (
"bytes"
- "fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
- "maunium.net/go/mautrix/crypto/olm"
)
const (
@@ -38,9 +36,6 @@ func (r *GroupMessage) Decode(input []byte) (err error) {
if err != nil {
return
}
- if r.Version != protocolVersion {
- return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
- }
for {
// Read Key
diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go
index b161a2d1..8bb6e0cd 100644
--- a/crypto/goolm/message/message.go
+++ b/crypto/goolm/message/message.go
@@ -2,12 +2,10 @@ package message
import (
"bytes"
- "fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
- "maunium.net/go/mautrix/crypto/olm"
)
const (
@@ -42,9 +40,6 @@ func (r *Message) Decode(input []byte) (err error) {
if err != nil {
return
}
- if r.Version != protocolVersion {
- return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
- }
for {
// Read Key
diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go
index 4e3d495d..22ebf9c3 100644
--- a/crypto/goolm/message/prekey_message.go
+++ b/crypto/goolm/message/prekey_message.go
@@ -1,7 +1,6 @@
package message
import (
- "fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm/crypto"
@@ -23,11 +22,6 @@ type PreKeyMessage struct {
Message []byte `json:"message"`
}
-// TODO deduplicate constant with one in session/olm_session.go
-const (
- protocolVersion = 0x3
-)
-
// Decodes decodes the input and populates the corresponding fileds.
func (r *PreKeyMessage) Decode(input []byte) (err error) {
r.Version = 0
@@ -47,9 +41,6 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) {
}
return
}
- if r.Version != protocolVersion {
- return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
- }
for {
// Read Key
diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go
index d58dbb21..956868b2 100644
--- a/crypto/goolm/message/session_export.go
+++ b/crypto/goolm/message/session_export.go
@@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error {
return fmt.Errorf("decrypt: %w", olm.ErrBadInput)
}
if input[0] != sessionExportVersion {
- return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion)
+ return fmt.Errorf("decrypt: %w", olm.ErrBadVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go
index d04ef15a..16240945 100644
--- a/crypto/goolm/message/session_sharing.go
+++ b/crypto/goolm/message/session_sharing.go
@@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
}
s.PublicKey = publicKey
if input[0] != sessionSharingVersion {
- return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion)
+ return fmt.Errorf("verify: %w", olm.ErrBadVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go
index cdb20eb1..afb01f74 100644
--- a/crypto/goolm/pk/decryption.go
+++ b/crypto/goolm/pk/decryption.go
@@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error {
if pickledVersion == decryptionPickleVersionLibOlm {
return a.KeyPair.UnpickleLibOlm(decoder)
} else {
- return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm)
+ return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm)
}
}
diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go
index 2897d9b0..23f67ddf 100644
--- a/crypto/goolm/pk/encryption.go
+++ b/crypto/goolm/pk/encryption.go
@@ -37,9 +37,6 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat
return nil, nil, err
}
cipher, err := aessha2.NewAESSHA2(sharedSecret, nil)
- if err != nil {
- return nil, nil, err
- }
ciphertext, err = cipher.Encrypt(plaintext)
if err != nil {
return nil, nil, err
diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go
index 0e27b568..b7af6a5b 100644
--- a/crypto/goolm/pk/register.go
+++ b/crypto/goolm/pk/register.go
@@ -8,7 +8,7 @@ package pk
import "maunium.net/go/mautrix/crypto/olm"
-func Register() {
+func init() {
olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) {
return NewSigningFromSeed(seed)
}
diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go
index 9901ada8..229c9bd2 100644
--- a/crypto/goolm/ratchet/olm.go
+++ b/crypto/goolm/ratchet/olm.go
@@ -142,7 +142,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) {
return nil, err
}
if message.Version != protocolVersion {
- return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion)
+ return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion)
}
if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 {
return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat)
diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go
index 800f567f..80ed206b 100644
--- a/crypto/goolm/register.go
+++ b/crypto/goolm/register.go
@@ -7,23 +7,19 @@
package goolm
import (
- "maunium.net/go/mautrix/crypto/goolm/account"
- "maunium.net/go/mautrix/crypto/goolm/pk"
- "maunium.net/go/mautrix/crypto/goolm/session"
+ // Need to import these subpackages to ensure they are registered
+ _ "maunium.net/go/mautrix/crypto/goolm/account"
+ _ "maunium.net/go/mautrix/crypto/goolm/pk"
+ _ "maunium.net/go/mautrix/crypto/goolm/session"
+
"maunium.net/go/mautrix/crypto/olm"
)
-func Register() {
- olm.Driver = "goolm"
-
+func init() {
olm.GetVersion = func() (major, minor, patch uint8) {
return 3, 2, 15
}
olm.SetPickleKeyImpl = func(key []byte) {
panic("gob and json encoding is deprecated and not supported with goolm")
}
-
- account.Register()
- pk.Register()
- session.Register()
}
diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go
index 7ccbd26d..80dd71cc 100644
--- a/crypto/goolm/session/megolm_inbound_session.go
+++ b/crypto/goolm/session/megolm_inbound_session.go
@@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet,
}
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
// the counter is before our initial ratchet - we can't decode this
- return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex)
+ return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable)
}
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
copiedRatchet := o.InitialRatchet
@@ -126,7 +126,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error)
return nil, 0, err
}
if msg.Version != protocolVersion {
- return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion)
+ return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion)
}
if msg.Ciphertext == nil || !msg.HasMessageIndex {
return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat)
@@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error {
return err
}
if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 {
- return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
+ return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
}
if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil {
diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go
index 7f923534..2b8e1c84 100644
--- a/crypto/goolm/session/megolm_outbound_session.go
+++ b/crypto/goolm/session/megolm_outbound_session.go
@@ -101,10 +101,8 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
- if err != nil {
- return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err)
- } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
- return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
+ if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
+ return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
}
if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil {
return err
diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go
index a1cb8d66..b99ab630 100644
--- a/crypto/goolm/session/olm_session.go
+++ b/crypto/goolm/session/olm_session.go
@@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received
msg := message.Message{}
err = msg.Decode(oneTimeMsg.Message)
if err != nil {
- return nil, fmt.Errorf("message decode: %w", err)
+ return nil, fmt.Errorf("Message decode: %w", err)
}
if len(msg.RatchetKey) == 0 {
- return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat)
+ return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat)
}
//Init Ratchet
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
@@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID {
copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey)
copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey)
hash := sha256.Sum256(message)
- res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:]))
+ res := id.SessionID(goolmbase64.Encode(hash[:]))
return res
}
@@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e
if len(crypttext) == 0 {
return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput)
}
- decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext)
+ decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext))
if err != nil {
return nil, err
}
@@ -365,9 +365,6 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error {
func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
- if err != nil {
- return fmt.Errorf("unpickle olmSession: failed to read version: %w", err)
- }
var includesChainIndex bool
switch pickledVersion {
@@ -376,7 +373,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
case uint32(0x80000001):
includesChainIndex = true
default:
- return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
+ return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
}
if o.ReceivedMessage, err = decoder.ReadBool(); err != nil {
diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go
index b95a44ac..09ed42d4 100644
--- a/crypto/goolm/session/register.go
+++ b/crypto/goolm/session/register.go
@@ -10,11 +10,11 @@ import (
"maunium.net/go/mautrix/crypto/olm"
)
-func Register() {
+func init() {
// Inbound Session
olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
if len(key) == 0 {
key = []byte(" ")
@@ -23,13 +23,13 @@ func Register() {
}
olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
return NewMegolmInboundSession(sessionKey)
}
olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
return NewMegolmInboundSessionFromExport(sessionKey)
}
@@ -40,7 +40,7 @@ func Register() {
// Outbound Session
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
lenKey := len(key)
if lenKey == 0 {
diff --git a/crypto/keybackup.go b/crypto/keybackup.go
index 7b3c30db..d8b3d715 100644
--- a/crypto/keybackup.go
+++ b/crypto/keybackup.go
@@ -56,12 +56,11 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context,
// ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private
// key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing
// from a verified device belonging to the same user."
- if megolmBackupKey != nil {
- megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
- if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
- log.Debug().Msg("Key backup is trusted based on derived public key")
- return versionInfo, nil
- }
+ megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
+ if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
+ log.Debug().Msg("key backup is trusted based on derived public key")
+ return versionInfo, nil
+ } else {
log.Debug().
Stringer("expected_key", megolmBackupDerivedPublicKey).
Stringer("actual_key", versionInfo.AuthData.PublicKey).
@@ -200,14 +199,13 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving(
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
RoomID: roomID,
- ForwardingChains: keyBackupData.ForwardingKeyChain,
+ ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()),
id: sessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
KeyBackupVersion: version,
- KeySource: id.KeySourceBackup,
}, nil
}
diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go
index fd6f105d..47616a20 100644
--- a/crypto/keyexport_test.go
+++ b/crypto/keyexport_test.go
@@ -31,5 +31,5 @@ func TestExportKeys(t *testing.T) {
))
data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess})
assert.NoError(t, err)
- assert.Len(t, data, 893)
+ assert.Len(t, data, 836)
}
diff --git a/crypto/keyimport.go b/crypto/keyimport.go
index 3ffc74a5..36ad6b9c 100644
--- a/crypto/keyimport.go
+++ b/crypto/keyimport.go
@@ -108,20 +108,19 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
return false, ErrMismatchingExportedSessionID
}
igs := &InboundGroupSession{
- Internal: igsInternal,
- SigningKey: session.SenderClaimedKeys.Ed25519,
- SenderKey: session.SenderKey,
- RoomID: session.RoomID,
+ Internal: igsInternal,
+ SigningKey: session.SenderClaimedKeys.Ed25519,
+ SenderKey: session.SenderKey,
+ RoomID: session.RoomID,
+ // TODO should we add something here to mark the signing key as unverified like key requests do?
ForwardingChains: session.ForwardingChains,
- KeySource: id.KeySourceImport,
- ReceivedAt: time.Now().UTC(),
+
+ ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
firstKnownIndex := igs.Internal.FirstKnownIndex()
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex {
- // We already have an equivalent or better session in the store, so don't override it,
- // but do notify the session received callback just in case.
- mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex())
+ // We already have an equivalent or better session in the store, so don't override it.
return false, nil
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
diff --git a/crypto/keysharing.go b/crypto/keysharing.go
index 19a68c87..f1d427af 100644
--- a/crypto/keysharing.go
+++ b/crypto/keysharing.go
@@ -189,7 +189,6 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
- KeySource: id.KeySourceForward,
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
@@ -215,7 +214,6 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare
RoomID: request.RoomID,
Algorithm: request.Algorithm,
SessionID: request.SessionID,
- //lint:ignore SA1019 This is just echoing back the deprecated field
SenderKey: request.SenderKey,
Code: rejection.Code,
Reason: rejection.Reason,
@@ -265,14 +263,9 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev
log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing")
return &KeyShareRejectNoResponse
} else if !isShared {
- igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID)
- if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey {
- log.Debug().Msg("Rejecting key request for unshared session")
- return &KeyShareRejectNotRecipient
- }
- // Note: this case will also happen for redacted sessions and database errors
- log.Debug().Msg("Rejecting key request for session created by another device")
- return &KeyShareRejectNoResponse
+ // TODO differentiate session not shared with requester vs session not created by this device?
+ log.Debug().Msg("Rejecting key request for unshared session")
+ return &KeyShareRejectNotRecipient
}
log.Debug().Msg("Accepting key request for shared session")
return nil
@@ -330,9 +323,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
- if sender != mach.Client.UserID {
- mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
- }
+ mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
} else {
log.Error().Err(err).Msg("Failed to get group session to forward")
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
@@ -340,9 +331,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
return
} else if igs == nil {
log.Error().Msg("Didn't find group session to forward")
- if sender != mach.Client.UserID {
- mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
- }
+ mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
return
}
if internalID := igs.ID(); internalID != content.Body.SessionID {
@@ -367,7 +356,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
SessionID: igs.ID(),
SessionKey: string(exportedKey),
},
- SenderKey: igs.SenderKey,
+ SenderKey: content.Body.SenderKey,
ForwardingKeyChain: igs.ForwardingChains,
SenderClaimedKey: igs.SigningKey,
},
diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go
index 0350f083..cddce7ce 100644
--- a/crypto/libolm/account.go
+++ b/crypto/libolm/account.go
@@ -8,7 +8,6 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/json"
- "runtime"
"unsafe"
"github.com/tidwall/gjson"
@@ -23,6 +22,18 @@ type Account struct {
mem []byte
}
+func init() {
+ olm.InitNewAccount = func() (olm.Account, error) {
+ return NewAccount()
+ }
+ olm.InitBlankAccount = func() olm.Account {
+ return NewBlankAccount()
+ }
+ olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) {
+ return AccountFromPickled(pickled, key)
+ }
+}
+
// Ensure that [Account] implements [olm.Account].
var _ olm.Account = (*Account)(nil)
@@ -33,7 +44,7 @@ var _ olm.Account = (*Account)(nil)
// "INVALID_BASE64".
func AccountFromPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
a := NewBlankAccount()
return a, a.Unpickle(pickled, key)
@@ -42,7 +53,7 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) {
func NewBlankAccount() *Account {
memory := make([]byte, accountSize())
return &Account{
- int: C.olm_account(unsafe.Pointer(unsafe.SliceData(memory))),
+ int: C.olm_account(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
@@ -53,13 +64,12 @@ func NewAccount() (*Account, error) {
random := make([]byte, a.createRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
- panic(olm.ErrNotEnoughGoRandom)
+ panic(olm.NotEnoughGoRandom)
}
ret := C.olm_create_account(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(random)),
+ unsafe.Pointer(&random[0]),
C.size_t(len(random)))
- runtime.KeepAlive(random)
if ret == errorVal() {
return nil, a.lastError()
} else {
@@ -128,14 +138,14 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
// supplied key.
func (a *Account) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.ErrNoKeyProvided
+ return nil, olm.NoKeyProvided
}
pickled := make([]byte, a.pickleLen())
r := C.olm_pickle_account(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
+ unsafe.Pointer(&pickled[0]),
C.size_t(len(pickled)))
if r == errorVal() {
return nil, a.lastError()
@@ -145,13 +155,13 @@ func (a *Account) Pickle(key []byte) ([]byte, error) {
func (a *Account) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.ErrNoKeyProvided
+ return olm.NoKeyProvided
}
r := C.olm_unpickle_account(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
+ unsafe.Pointer(&pickled[0]),
C.size_t(len(pickled)))
if r == errorVal() {
return a.lastError()
@@ -198,7 +208,7 @@ func (a *Account) MarshalJSON() ([]byte, error) {
// Deprecated
func (a *Account) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.ErrInputNotJSONString
+ return olm.InputNotJSONString
}
if a.int == nil {
*a = *NewBlankAccount()
@@ -211,7 +221,7 @@ func (a *Account) IdentityKeysJSON() ([]byte, error) {
identityKeys := make([]byte, a.identityKeysLen())
r := C.olm_account_identity_keys(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(identityKeys)),
+ unsafe.Pointer(&identityKeys[0]),
C.size_t(len(identityKeys)))
if r == errorVal() {
return nil, a.lastError()
@@ -235,16 +245,15 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
// Account.
func (a *Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
- panic(olm.ErrEmptyInput)
+ panic(olm.EmptyInput)
}
signature := make([]byte, a.signatureLen())
r := C.olm_account_sign(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(message)),
+ unsafe.Pointer(&message[0]),
C.size_t(len(message)),
- unsafe.Pointer(unsafe.SliceData(signature)),
+ unsafe.Pointer(&signature[0]),
C.size_t(len(signature)))
- runtime.KeepAlive(message)
if r == errorVal() {
panic(a.lastError())
}
@@ -268,9 +277,8 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen())
r := C.olm_account_one_time_keys(
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(oneTimeKeysJSON)),
- C.size_t(len(oneTimeKeysJSON)),
- )
+ unsafe.Pointer(&oneTimeKeysJSON[0]),
+ C.size_t(len(oneTimeKeysJSON)))
if r == errorVal() {
return nil, a.lastError()
}
@@ -299,15 +307,13 @@ func (a *Account) GenOneTimeKeys(num uint) error {
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
_, err := rand.Read(random)
if err != nil {
- return olm.ErrNotEnoughGoRandom
+ return olm.NotEnoughGoRandom
}
r := C.olm_account_generate_one_time_keys(
(*C.OlmAccount)(a.int),
C.size_t(num),
- unsafe.Pointer(unsafe.SliceData(random)),
- C.size_t(len(random)),
- )
- runtime.KeepAlive(random)
+ unsafe.Pointer(&random[0]),
+ C.size_t(len(random)))
if r == errorVal() {
return a.lastError()
}
@@ -319,29 +325,23 @@ func (a *Account) GenOneTimeKeys(num uint) error {
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
s := NewBlankSession()
random := make([]byte, s.createOutboundRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
- panic(olm.ErrNotEnoughGoRandom)
+ panic(olm.NotEnoughGoRandom)
}
- theirIdentityKeyCopy := []byte(theirIdentityKey)
- theirOneTimeKeyCopy := []byte(theirOneTimeKey)
r := C.olm_create_outbound_session(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)),
- C.size_t(len(theirIdentityKeyCopy)),
- unsafe.Pointer(unsafe.SliceData(theirOneTimeKeyCopy)),
- C.size_t(len(theirOneTimeKeyCopy)),
- unsafe.Pointer(unsafe.SliceData(random)),
- C.size_t(len(random)),
- )
- runtime.KeepAlive(random)
- runtime.KeepAlive(theirIdentityKeyCopy)
- runtime.KeepAlive(theirOneTimeKeyCopy)
+ unsafe.Pointer(&([]byte(theirIdentityKey)[0])),
+ C.size_t(len(theirIdentityKey)),
+ unsafe.Pointer(&([]byte(theirOneTimeKey)[0])),
+ C.size_t(len(theirOneTimeKey)),
+ unsafe.Pointer(&random[0]),
+ C.size_t(len(random)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -357,17 +357,14 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
if len(oneTimeKeyMsg) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
s := NewBlankSession()
- oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_create_inbound_session(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
- C.size_t(len(oneTimeKeyMsgCopy)),
- )
- runtime.KeepAlive(oneTimeKeyMsgCopy)
+ unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
+ C.size_t(len(oneTimeKeyMsg)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -383,21 +380,16 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
- theirIdentityKeyCopy := []byte(*theirIdentityKey)
- oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
s := NewBlankSession()
r := C.olm_create_inbound_session_from(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
- unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)),
- C.size_t(len(theirIdentityKeyCopy)),
- unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
- C.size_t(len(oneTimeKeyMsgCopy)),
- )
- runtime.KeepAlive(theirIdentityKeyCopy)
- runtime.KeepAlive(oneTimeKeyMsgCopy)
+ unsafe.Pointer(&([]byte(*theirIdentityKey)[0])),
+ C.size_t(len(*theirIdentityKey)),
+ unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
+ C.size_t(len(oneTimeKeyMsg)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -410,8 +402,7 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTime
func (a *Account) RemoveOneTimeKeys(s olm.Session) error {
r := C.olm_remove_one_time_keys(
(*C.OlmAccount)(a.int),
- (*C.OlmSession)(s.(*Session).int),
- )
+ (*C.OlmSession)(s.(*Session).int))
if r == errorVal() {
return a.lastError()
}
diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go
index 6fb5512b..9ca415ee 100644
--- a/crypto/libolm/error.go
+++ b/crypto/libolm/error.go
@@ -11,21 +11,21 @@ import (
)
var errorMap = map[string]error{
- "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom,
- "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall,
- "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion,
- "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat,
- "BAD_MESSAGE_MAC": olm.ErrBadMAC,
- "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID,
- "INVALID_BASE64": olm.ErrLibolmInvalidBase64,
- "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey,
- "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion,
- "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle,
- "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey,
- "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex,
- "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle,
- "BAD_SIGNATURE": olm.ErrBadSignature,
- "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall,
+ "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom,
+ "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall,
+ "BAD_MESSAGE_VERSION": olm.BadMessageVersion,
+ "BAD_MESSAGE_FORMAT": olm.BadMessageFormat,
+ "BAD_MESSAGE_MAC": olm.BadMessageMAC,
+ "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID,
+ "INVALID_BASE64": olm.InvalidBase64,
+ "BAD_ACCOUNT_KEY": olm.BadAccountKey,
+ "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion,
+ "CORRUPTED_PICKLE": olm.CorruptedPickle,
+ "BAD_SESSION_KEY": olm.BadSessionKey,
+ "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex,
+ "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle,
+ "BAD_SIGNATURE": olm.BadSignature,
+ "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall,
}
func convertError(errCode string) error {
diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go
index 8815ac32..1e25748d 100644
--- a/crypto/libolm/inboundgroupsession.go
+++ b/crypto/libolm/inboundgroupsession.go
@@ -7,7 +7,6 @@ import "C"
import (
"bytes"
"encoding/base64"
- "runtime"
"unsafe"
"maunium.net/go/mautrix/crypto/olm"
@@ -21,6 +20,21 @@ type InboundGroupSession struct {
mem []byte
}
+func init() {
+ olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
+ return InboundGroupSessionFromPickled(pickled, key)
+ }
+ olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
+ return NewInboundGroupSession(sessionKey)
+ }
+ olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
+ return InboundGroupSessionImport(sessionKey)
+ }
+ olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession {
+ return NewBlankInboundGroupSession()
+ }
+}
+
// Ensure that [InboundGroupSession] implements [olm.InboundGroupSession].
var _ olm.InboundGroupSession = (*InboundGroupSession)(nil)
@@ -31,7 +45,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil)
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
lenKey := len(key)
if lenKey == 0 {
@@ -48,15 +62,13 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession,
// "OLM_BAD_SESSION_KEY".
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_init_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))),
- C.size_t(len(sessionKey)),
- )
- runtime.KeepAlive(sessionKey)
+ (*C.uint8_t)(&sessionKey[0]),
+ C.size_t(len(sessionKey)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -69,15 +81,13 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
// error will be "OLM_BAD_SESSION_KEY".
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_import_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))),
- C.size_t(len(sessionKey)),
- )
- runtime.KeepAlive(sessionKey)
+ (*C.uint8_t)(&sessionKey[0]),
+ C.size_t(len(sessionKey)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -94,7 +104,7 @@ func inboundGroupSessionSize() uint {
func NewBlankInboundGroupSession() *InboundGroupSession {
memory := make([]byte, inboundGroupSessionSize())
return &InboundGroupSession{
- int: C.olm_inbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))),
+ int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
@@ -124,17 +134,15 @@ func (s *InboundGroupSession) pickleLen() uint {
// InboundGroupSession using the supplied key.
func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.ErrNoKeyProvided
+ return nil, olm.NoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
- C.size_t(len(pickled)),
- )
- runtime.KeepAlive(key)
+ unsafe.Pointer(&pickled[0]),
+ C.size_t(len(pickled)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -143,18 +151,16 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.ErrNoKeyProvided
+ return olm.NoKeyProvided
} else if len(pickled) == 0 {
- return olm.ErrEmptyInput
+ return olm.EmptyInput
}
r := C.olm_unpickle_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
- C.size_t(len(pickled)),
- )
- runtime.KeepAlive(key)
+ unsafe.Pointer(&pickled[0]),
+ C.size_t(len(pickled)))
if r == errorVal() {
return s.lastError()
}
@@ -200,7 +206,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.ErrInputNotJSONString
+ return olm.InputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankInboundGroupSession()
@@ -217,16 +223,14 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
// will be "BAD_MESSAGE_FORMAT".
func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
if len(message) == 0 {
- return 0, olm.ErrEmptyInput
+ return 0, olm.EmptyInput
}
// olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it
- messageCopy := bytes.Clone(message)
+ message = bytes.Clone(message)
r := C.olm_group_decrypt_max_plaintext_length(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))),
- C.size_t(len(messageCopy)),
- )
- runtime.KeepAlive(messageCopy)
+ (*C.uint8_t)(&message[0]),
+ C.size_t(len(message)))
if r == errorVal() {
return 0, s.lastError()
}
@@ -244,24 +248,23 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro
// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
if len(message) == 0 {
- return nil, 0, olm.ErrEmptyInput
+ return nil, 0, olm.EmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message)
if err != nil {
return nil, 0, err
}
- messageCopy := bytes.Clone(message)
+ messageCopy := make([]byte, len(message))
+ copy(messageCopy, message)
plaintext := make([]byte, decryptMaxPlaintextLen)
var messageIndex uint32
r := C.olm_group_decrypt(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))),
+ (*C.uint8_t)(&messageCopy[0]),
C.size_t(len(messageCopy)),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))),
+ (*C.uint8_t)(&plaintext[0]),
C.size_t(len(plaintext)),
- (*C.uint32_t)(unsafe.Pointer(&messageIndex)),
- )
- runtime.KeepAlive(messageCopy)
+ (*C.uint32_t)(&messageIndex))
if r == errorVal() {
return nil, 0, s.lastError()
}
@@ -278,9 +281,8 @@ func (s *InboundGroupSession) ID() id.SessionID {
sessionID := make([]byte, s.sessionIdLen())
r := C.olm_inbound_group_session_id(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))),
- C.size_t(len(sessionID)),
- )
+ (*C.uint8_t)(&sessionID[0]),
+ C.size_t(len(sessionID)))
if r == errorVal() {
panic(s.lastError())
}
@@ -316,10 +318,9 @@ func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
key := make([]byte, s.exportLen())
r := C.olm_export_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(key))),
+ (*C.uint8_t)(&key[0]),
C.size_t(len(key)),
- C.uint32_t(messageIndex),
- )
+ C.uint32_t(messageIndex))
if r == errorVal() {
return nil, s.lastError()
}
diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go
index ca5b68f7..a21f8d4a 100644
--- a/crypto/libolm/outboundgroupsession.go
+++ b/crypto/libolm/outboundgroupsession.go
@@ -7,7 +7,6 @@ import "C"
import (
"crypto/rand"
"encoding/base64"
- "runtime"
"unsafe"
"maunium.net/go/mautrix/crypto/olm"
@@ -21,6 +20,18 @@ type OutboundGroupSession struct {
mem []byte
}
+func init() {
+ olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
+ if len(pickled) == 0 {
+ return nil, olm.EmptyInput
+ }
+ s := NewBlankOutboundGroupSession()
+ return s, s.Unpickle(pickled, key)
+ }
+ olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() }
+ olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() }
+}
+
// Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession].
var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil)
@@ -33,10 +44,8 @@ func NewOutboundGroupSession() (*OutboundGroupSession, error) {
}
r := C.olm_init_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(random))),
- C.size_t(len(random)),
- )
- runtime.KeepAlive(random)
+ (*C.uint8_t)(&random[0]),
+ C.size_t(len(random)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -53,7 +62,7 @@ func outboundGroupSessionSize() uint {
func NewBlankOutboundGroupSession() *OutboundGroupSession {
memory := make([]byte, outboundGroupSessionSize())
return &OutboundGroupSession{
- int: C.olm_outbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))),
+ int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
@@ -84,17 +93,15 @@ func (s *OutboundGroupSession) pickleLen() uint {
// OutboundGroupSession using the supplied key.
func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.ErrNoKeyProvided
+ return nil, olm.NoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
- C.size_t(len(pickled)),
- )
- runtime.KeepAlive(key)
+ unsafe.Pointer(&pickled[0]),
+ C.size_t(len(pickled)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -103,17 +110,14 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.ErrNoKeyProvided
+ return olm.NoKeyProvided
}
r := C.olm_unpickle_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
- C.size_t(len(pickled)),
- )
- runtime.KeepAlive(pickled)
- runtime.KeepAlive(key)
+ unsafe.Pointer(&pickled[0]),
+ C.size_t(len(pickled)))
if r == errorVal() {
return s.lastError()
}
@@ -159,7 +163,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.ErrInputNotJSONString
+ return olm.InputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankOutboundGroupSession()
@@ -183,17 +187,15 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
// as base64.
func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if len(plaintext) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
message := make([]byte, s.encryptMsgLen(len(plaintext)))
r := C.olm_group_encrypt(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))),
+ (*C.uint8_t)(&plaintext[0]),
C.size_t(len(plaintext)),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))),
- C.size_t(len(message)),
- )
- runtime.KeepAlive(plaintext)
+ (*C.uint8_t)(&message[0]),
+ C.size_t(len(message)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -210,9 +212,8 @@ func (s *OutboundGroupSession) ID() id.SessionID {
sessionID := make([]byte, s.sessionIdLen())
r := C.olm_outbound_group_session_id(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))),
- C.size_t(len(sessionID)),
- )
+ (*C.uint8_t)(&sessionID[0]),
+ C.size_t(len(sessionID)))
if r == errorVal() {
panic(s.lastError())
}
@@ -235,9 +236,8 @@ func (s *OutboundGroupSession) Key() string {
sessionKey := make([]byte, s.sessionKeyLen())
r := C.olm_outbound_group_session_key(
(*C.OlmOutboundGroupSession)(s.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))),
- C.size_t(len(sessionKey)),
- )
+ (*C.uint8_t)(&sessionKey[0]),
+ C.size_t(len(sessionKey)))
if r == errorVal() {
panic(s.lastError())
}
diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go
index 2683cf15..db8d35c5 100644
--- a/crypto/libolm/pk.go
+++ b/crypto/libolm/pk.go
@@ -14,7 +14,6 @@ import "C"
import (
"crypto/rand"
"encoding/json"
- "runtime"
"unsafe"
"github.com/tidwall/sjson"
@@ -35,6 +34,16 @@ type PKSigning struct {
// Ensure that [PKSigning] implements [olm.PKSigning].
var _ olm.PKSigning = (*PKSigning)(nil)
+func init() {
+ olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() }
+ olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) {
+ return NewPKSigningFromSeed(seed)
+ }
+ olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) {
+ return NewPkDecryption(privateKey)
+ }
+}
+
func pkSigningSize() uint {
return uint(C.olm_pk_signing_size())
}
@@ -54,7 +63,7 @@ func pkSigningSignatureLength() uint {
func newBlankPKSigning() *PKSigning {
memory := make([]byte, pkSigningSize())
return &PKSigning{
- int: C.olm_pk_signing(unsafe.Pointer(unsafe.SliceData(memory))),
+ int: C.olm_pk_signing(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
@@ -64,14 +73,9 @@ func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) {
p := newBlankPKSigning()
p.clear()
pubKey := make([]byte, pkSigningPublicKeyLength())
- r := C.olm_pk_signing_key_from_seed(
- (*C.OlmPkSigning)(p.int),
- unsafe.Pointer(unsafe.SliceData(pubKey)),
- C.size_t(len(pubKey)),
- unsafe.Pointer(unsafe.SliceData(seed)),
- C.size_t(len(seed)),
- )
- if r == errorVal() {
+ if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int),
+ unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
+ unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() {
return nil, p.lastError()
}
p.publicKey = id.Ed25519(pubKey)
@@ -86,7 +90,7 @@ func NewPKSigning() (*PKSigning, error) {
seed := make([]byte, pkSigningSeedLength())
_, err := rand.Read(seed)
if err != nil {
- panic(olm.ErrNotEnoughGoRandom)
+ panic(olm.NotEnoughGoRandom)
}
pk, err := NewPKSigningFromSeed(seed)
return pk, err
@@ -108,15 +112,8 @@ func (p *PKSigning) clear() {
// Sign creates a signature for the given message using this key.
func (p *PKSigning) Sign(message []byte) ([]byte, error) {
signature := make([]byte, pkSigningSignatureLength())
- r := C.olm_pk_sign(
- (*C.OlmPkSigning)(p.int),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))),
- C.size_t(len(message)),
- (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(signature))),
- C.size_t(len(signature)),
- )
- runtime.KeepAlive(message)
- if r == errorVal() {
+ if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)),
+ (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() {
return nil, p.lastError()
}
return signature, nil
@@ -160,21 +157,15 @@ func pkDecryptionPublicKeySize() uint {
func NewPkDecryption(privateKey []byte) (*PKDecryption, error) {
memory := make([]byte, pkDecryptionSize())
p := &PKDecryption{
- int: C.olm_pk_decryption(unsafe.Pointer(unsafe.SliceData(memory))),
+ int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])),
mem: memory,
}
p.clear()
pubKey := make([]byte, pkDecryptionPublicKeySize())
- r := C.olm_pk_key_from_private(
- (*C.OlmPkDecryption)(p.int),
- unsafe.Pointer(unsafe.SliceData(pubKey)),
- C.size_t(len(pubKey)),
- unsafe.Pointer(unsafe.SliceData(privateKey)),
- C.size_t(len(privateKey)),
- )
- runtime.KeepAlive(privateKey)
- if r == errorVal() {
+ if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int),
+ unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
+ unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() {
return nil, p.lastError()
}
p.publicKey = pubKey
@@ -187,26 +178,14 @@ func (p *PKDecryption) PublicKey() id.Curve25519 {
}
func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
- maxPlaintextLength := uint(C.olm_pk_max_plaintext_length(
- (*C.OlmPkDecryption)(p.int),
- C.size_t(len(ciphertext)),
- ))
+ maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext))))
plaintext := make([]byte, maxPlaintextLength)
- size := C.olm_pk_decrypt(
- (*C.OlmPkDecryption)(p.int),
- unsafe.Pointer(unsafe.SliceData(ephemeralKey)),
- C.size_t(len(ephemeralKey)),
- unsafe.Pointer(unsafe.SliceData(mac)),
- C.size_t(len(mac)),
- unsafe.Pointer(unsafe.SliceData(ciphertext)),
- C.size_t(len(ciphertext)),
- unsafe.Pointer(unsafe.SliceData(plaintext)),
- C.size_t(len(plaintext)),
- )
- runtime.KeepAlive(ephemeralKey)
- runtime.KeepAlive(mac)
- runtime.KeepAlive(ciphertext)
+ size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int),
+ unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)),
+ unsafe.Pointer(&mac[0]), C.size_t(len(mac)),
+ unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)),
+ unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext)))
if size == errorVal() {
return nil, p.lastError()
}
diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go
index ddf84613..a423a7d0 100644
--- a/crypto/libolm/register.go
+++ b/crypto/libolm/register.go
@@ -3,73 +3,19 @@ package libolm
// #cgo LDFLAGS: -lolm -lstdc++
// #include
import "C"
-import (
- "unsafe"
-
- "maunium.net/go/mautrix/crypto/olm"
-)
+import "maunium.net/go/mautrix/crypto/olm"
var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm")
-func Register() {
- olm.Driver = "libolm"
-
+func init() {
olm.GetVersion = func() (major, minor, patch uint8) {
C.olm_get_library_version(
- (*C.uint8_t)(unsafe.Pointer(&major)),
- (*C.uint8_t)(unsafe.Pointer(&minor)),
- (*C.uint8_t)(unsafe.Pointer(&patch)))
+ (*C.uint8_t)(&major),
+ (*C.uint8_t)(&minor),
+ (*C.uint8_t)(&patch))
return 3, 2, 15
}
olm.SetPickleKeyImpl = func(key []byte) {
pickleKey = key
}
-
- olm.InitNewAccount = func() (olm.Account, error) {
- return NewAccount()
- }
- olm.InitBlankAccount = func() olm.Account {
- return NewBlankAccount()
- }
- olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) {
- return AccountFromPickled(pickled, key)
- }
-
- olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) {
- return SessionFromPickled(pickled, key)
- }
- olm.InitNewBlankSession = func() olm.Session {
- return NewBlankSession()
- }
-
- olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() }
- olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) {
- return NewPKSigningFromSeed(seed)
- }
- olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) {
- return NewPkDecryption(privateKey)
- }
-
- olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
- return InboundGroupSessionFromPickled(pickled, key)
- }
- olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
- return NewInboundGroupSession(sessionKey)
- }
- olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
- return InboundGroupSessionImport(sessionKey)
- }
- olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession {
- return NewBlankInboundGroupSession()
- }
-
- olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
- if len(pickled) == 0 {
- return nil, olm.ErrEmptyInput
- }
- s := NewBlankOutboundGroupSession()
- return s, s.Unpickle(pickled, key)
- }
- olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() }
- olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() }
}
diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go
index 1441df26..4cc22809 100644
--- a/crypto/libolm/session.go
+++ b/crypto/libolm/session.go
@@ -23,7 +23,6 @@ import "C"
import (
"crypto/rand"
"encoding/base64"
- "runtime"
"unsafe"
"maunium.net/go/mautrix/crypto/olm"
@@ -39,6 +38,15 @@ type Session struct {
// Ensure that [Session] implements [olm.Session].
var _ olm.Session = (*Session)(nil)
+func init() {
+ olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) {
+ return SessionFromPickled(pickled, key)
+ }
+ olm.InitNewBlankSession = func() olm.Session {
+ return NewBlankSession()
+ }
+}
+
// sessionSize is the size of a session object in bytes.
func sessionSize() uint {
return uint(C.olm_session_size())
@@ -51,7 +59,7 @@ func sessionSize() uint {
// "INVALID_BASE64".
func SessionFromPickled(pickled, key []byte) (*Session, error) {
if len(pickled) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
s := NewBlankSession()
return s, s.Unpickle(pickled, key)
@@ -60,7 +68,7 @@ func SessionFromPickled(pickled, key []byte) (*Session, error) {
func NewBlankSession() *Session {
memory := make([]byte, sessionSize())
return &Session{
- int: C.olm_session(unsafe.Pointer(unsafe.SliceData(memory))),
+ int: C.olm_session(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
@@ -118,16 +126,13 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint {
// will be "BAD_MESSAGE_FORMAT".
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
if len(message) == 0 {
- return 0, olm.ErrEmptyInput
+ return 0, olm.EmptyInput
}
- messageCopy := []byte(message)
r := C.olm_decrypt_max_plaintext_length(
(*C.OlmSession)(s.int),
C.size_t(msgType),
- unsafe.Pointer(unsafe.SliceData((messageCopy))),
- C.size_t(len(messageCopy)),
- )
- runtime.KeepAlive(messageCopy)
+ unsafe.Pointer(C.CString(message)),
+ C.size_t(len(message)))
if r == errorVal() {
return 0, s.lastError()
}
@@ -138,16 +143,15 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType)
// supplied key.
func (s *Session) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.ErrNoKeyProvided
+ return nil, olm.NoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_session(
(*C.OlmSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
+ unsafe.Pointer(&pickled[0]),
C.size_t(len(pickled)))
- runtime.KeepAlive(key)
if r == errorVal() {
panic(s.lastError())
}
@@ -158,16 +162,14 @@ func (s *Session) Pickle(key []byte) ([]byte, error) {
// provided key. This function mutates the input pickled data slice.
func (s *Session) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.ErrNoKeyProvided
+ return olm.NoKeyProvided
}
r := C.olm_unpickle_session(
(*C.OlmSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(key)),
+ unsafe.Pointer(&key[0]),
C.size_t(len(key)),
- unsafe.Pointer(unsafe.SliceData(pickled)),
+ unsafe.Pointer(&pickled[0]),
C.size_t(len(pickled)))
- runtime.KeepAlive(pickled)
- runtime.KeepAlive(key)
if r == errorVal() {
return s.lastError()
}
@@ -213,7 +215,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *Session) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.ErrInputNotJSONString
+ return olm.InputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankSession()
@@ -227,9 +229,8 @@ func (s *Session) ID() id.SessionID {
sessionID := make([]byte, s.idLen())
r := C.olm_session_id(
(*C.OlmSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(sessionID)),
- C.size_t(len(sessionID)),
- )
+ unsafe.Pointer(&sessionID[0]),
+ C.size_t(len(sessionID)))
if r == errorVal() {
panic(s.lastError())
}
@@ -256,15 +257,12 @@ func (s *Session) HasReceivedMessage() bool {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
if len(oneTimeKeyMsg) == 0 {
- return false, olm.ErrEmptyInput
+ return false, olm.EmptyInput
}
- oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_matches_inbound_session(
(*C.OlmSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
- C.size_t(len(oneTimeKeyMsgCopy)),
- )
- runtime.KeepAlive(oneTimeKeyMsgCopy)
+ unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]),
+ C.size_t(len(oneTimeKeyMsg)))
if r == 1 {
return true, nil
} else if r == 0 {
@@ -284,19 +282,14 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
- return false, olm.ErrEmptyInput
+ return false, olm.EmptyInput
}
- theirIdentityKeyCopy := []byte(theirIdentityKey)
- oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_matches_inbound_session_from(
(*C.OlmSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)),
- C.size_t(len(theirIdentityKeyCopy)),
- unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)),
- C.size_t(len(oneTimeKeyMsgCopy)),
- )
- runtime.KeepAlive(theirIdentityKeyCopy)
- runtime.KeepAlive(oneTimeKeyMsgCopy)
+ unsafe.Pointer(&([]byte(theirIdentityKey))[0]),
+ C.size_t(len(theirIdentityKey)),
+ unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]),
+ C.size_t(len(oneTimeKeyMsg)))
if r == 1 {
return true, nil
} else if r == 0 {
@@ -325,28 +318,25 @@ func (s *Session) EncryptMsgType() id.OlmMsgType {
// as base64.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
- return 0, nil, olm.ErrEmptyInput
+ return 0, nil, olm.EmptyInput
}
// Make the slice be at least length 1
random := make([]byte, s.encryptRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
// TODO can we just return err here?
- return 0, nil, olm.ErrNotEnoughGoRandom
+ return 0, nil, olm.NotEnoughGoRandom
}
messageType := s.EncryptMsgType()
message := make([]byte, s.encryptMsgLen(len(plaintext)))
r := C.olm_encrypt(
(*C.OlmSession)(s.int),
- unsafe.Pointer(unsafe.SliceData(plaintext)),
+ unsafe.Pointer(&plaintext[0]),
C.size_t(len(plaintext)),
- unsafe.Pointer(unsafe.SliceData(random)),
+ unsafe.Pointer(&random[0]),
C.size_t(len(random)),
- unsafe.Pointer(unsafe.SliceData(message)),
- C.size_t(len(message)),
- )
- runtime.KeepAlive(plaintext)
- runtime.KeepAlive(random)
+ unsafe.Pointer(&message[0]),
+ C.size_t(len(message)))
if r == errorVal() {
return 0, nil, s.lastError()
}
@@ -362,7 +352,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
// "BAD_MESSAGE_MAC".
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
if len(message) == 0 {
- return nil, olm.ErrEmptyInput
+ return nil, olm.EmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType)
if err != nil {
@@ -373,12 +363,10 @@ func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error)
r := C.olm_decrypt(
(*C.OlmSession)(s.int),
C.size_t(msgType),
- unsafe.Pointer(unsafe.SliceData(messageCopy)),
+ unsafe.Pointer(&(messageCopy)[0]),
C.size_t(len(messageCopy)),
- unsafe.Pointer(unsafe.SliceData(plaintext)),
- C.size_t(len(plaintext)),
- )
- runtime.KeepAlive(messageCopy)
+ unsafe.Pointer(&plaintext[0]),
+ C.size_t(len(plaintext)))
if r == errorVal() {
return nil, s.lastError()
}
@@ -395,7 +383,6 @@ func (s *Session) Describe() string {
C.meowlm_session_describe(
(*C.OlmSession)(s.int),
desc,
- C.size_t(maxDescribeSize),
- )
+ C.size_t(maxDescribeSize))
return C.GoString(desc)
}
diff --git a/crypto/machine.go b/crypto/machine.go
index fa051f94..cac91bf8 100644
--- a/crypto/machine.go
+++ b/crypto/machine.go
@@ -15,12 +15,10 @@ import (
"time"
"github.com/rs/zerolog"
- "go.mau.fi/util/ptr"
"go.mau.fi/util/exzerolog"
"maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -35,11 +33,9 @@ type OlmMachine struct {
CryptoStore Store
StateStore StateStore
- backgroundCtx context.Context
- cancelBackgroundCtx context.CancelFunc
+ BackgroundCtx context.Context
PlaintextMentions bool
- MSC4392Relations bool
AllowEncryptedState bool
// Never ask the server for keys automatically as a side effect during Megolm decryption.
@@ -124,6 +120,8 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor
CryptoStore: cryptoStore,
StateStore: stateStore,
+ BackgroundCtx: context.Background(),
+
SendKeysMinTrust: id.TrustStateUnset,
ShareKeysMinTrust: id.TrustStateCrossSignedTOFU,
@@ -136,7 +134,6 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor
recentlyUnwedged: make(map[id.IdentityKey]time.Time),
secretListeners: make(map[string]chan<- string),
}
- mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(context.Background())
mach.AllowKeyShare = mach.defaultAllowKeyShare
return mach
}
@@ -149,11 +146,6 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
return log
}
-func (mach *OlmMachine) SetBackgroundCtx(ctx context.Context) {
- mach.cancelBackgroundCtx()
- mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(ctx)
-}
-
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
// This must be called before using the machine.
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
@@ -164,23 +156,9 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) {
if mach.account == nil {
mach.account = NewOlmAccount()
}
- zerolog.Ctx(ctx).Debug().
- Str("machine_ptr", fmt.Sprintf("%p", mach)).
- Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)).
- Str("olm_driver", olm.Driver).
- Msg("Loaded olm account")
return nil
}
-func (mach *OlmMachine) Destroy() {
- mach.Log.Debug().
- Str("machine_ptr", fmt.Sprintf("%p", mach)).
- Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)).
- Msg("Destroying olm machine")
- mach.cancelBackgroundCtx()
- // TODO actually destroy something?
-}
-
func (mach *OlmMachine) saveAccount(ctx context.Context) error {
err := mach.CryptoStore.PutAccount(ctx, mach.account)
if err != nil {
@@ -206,7 +184,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error {
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
start := time.Now()
return func() {
- duration := time.Since(start)
+ duration := time.Now().Sub(start)
if duration > expectedDuration {
zerolog.Ctx(ctx).Warn().
Str("action", thing).
@@ -383,7 +361,7 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event)
Msg("Got membership state change, invalidating group session in room")
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
if err != nil {
- mach.Log.Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session")
+ mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
}
}
@@ -603,7 +581,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
- log.Err(err).Stringer("session_id", sessionID).Msg("Failed to store new inbound group session")
+ log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return fmt.Errorf("failed to store new inbound group session: %w", err)
}
mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex())
@@ -730,7 +708,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
start := time.Now()
mach.otkUploadLock.Lock()
defer mach.otkUploadLock.Unlock()
- if mach.lastOTKUpload.Add(1*time.Minute).After(start) || (currentOTKCount < 0 && mach.account.Shared) {
+ if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 {
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count")
resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{})
if err != nil {
diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go
deleted file mode 100644
index fd40d795..00000000
--- a/crypto/machine_bench_test.go
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package crypto_test
-
-import (
- "context"
- "fmt"
- "math/rand/v2"
- "testing"
-
- "github.com/rs/zerolog"
- globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
- "github.com/stretchr/testify/require"
-
- "maunium.net/go/mautrix/crypto/cryptohelper"
- "maunium.net/go/mautrix/id"
- "maunium.net/go/mautrix/mockserver"
-)
-
-func randomDeviceCount(r *rand.Rand) int {
- k := 1
- for k < 10 && r.IntN(3) > 0 {
- k++
- }
- return k
-}
-
-func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) {
- globallog.Logger = zerolog.Nop()
- server := mockserver.Create(b)
- server.PopOTKs = false
- server.MemoryStore = false
- var i int
- var shareTargets []id.UserID
- r := rand.New(rand.NewPCG(293, 0))
- var totalDeviceCount int
- for i = 1; i < 1000; i++ {
- userID := id.UserID(fmt.Sprintf("@user%d:localhost", i))
- deviceCount := randomDeviceCount(r)
- for j := 0; j < deviceCount; j++ {
- client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j)))
- mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
- keysCache, err := mach.GenerateCrossSigningKeys()
- require.NoError(b, err)
- err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
- require.NoError(b, err)
- }
- totalDeviceCount += deviceCount
- shareTargets = append(shareTargets, userID)
- }
- for b.Loop() {
- client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i)))
- mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine()
- keysCache, err := mach.GenerateCrossSigningKeys()
- require.NoError(b, err)
- err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil)
- require.NoError(b, err)
- err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets)
- require.NoError(b, err)
- i++
- }
- fmt.Println(totalDeviceCount, "devices total")
-}
diff --git a/crypto/machine_test.go b/crypto/machine_test.go
index 872c3ac4..59c86236 100644
--- a/crypto/machine_test.go
+++ b/crypto/machine_test.go
@@ -36,15 +36,20 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID,
func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
client, err := mautrix.NewClient("http://localhost", userID, "token")
- require.NoError(t, err, "Error creating client")
+ if err != nil {
+ t.Fatalf("Error creating client: %v", err)
+ }
client.DeviceID = "device1"
gobStore := NewMemoryStore(nil)
- require.NoError(t, err, "Error creating Gob store")
+ if err != nil {
+ t.Fatalf("Error creating Gob store: %v", err)
+ }
machine := NewOlmMachine(client, nil, gobStore, mockStateStore{})
- err = machine.Load(context.TODO())
- require.NoError(t, err, "Error creating account")
+ if err := machine.Load(context.TODO()); err != nil {
+ t.Fatalf("Error creating account: %v", err)
+ }
return machine
}
@@ -77,7 +82,9 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
// create outbound olm session for sending machine using OTK
olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key)
- require.NoError(t, err, "Error creating outbound olm session")
+ if err != nil {
+ t.Errorf("Failed to create outbound olm session: %v", err)
+ }
// store sender device identity in receiving machine store
machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{
@@ -114,21 +121,29 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
Type: event.ToDeviceEncrypted,
Sender: "user1",
}, senderKey, content.Type, content.Body)
- require.NoError(t, err, "Error decrypting olm ciphertext")
-
+ if err != nil {
+ t.Errorf("Error decrypting olm content: %v", err)
+ }
// store room key in new inbound group session
roomKeyEvt := decrypted.Content.AsRoomKey()
igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false)
- require.NoError(t, err, "Error creating inbound group session")
- err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs)
- require.NoError(t, err, "Error storing inbound group session")
+ if err != nil {
+ t.Errorf("Error creating inbound megolm session: %v", err)
+ }
+ if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil {
+ t.Errorf("Error storing inbound megolm session: %v", err)
+ }
}
// encrypt event with megolm session in sending machine
eventContent := map[string]string{"hello": "world"}
encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
- require.NoError(t, err, "Error encrypting megolm event")
- assert.Equal(t, 1, megolmOutSession.MessageCount)
+ if err != nil {
+ t.Errorf("Error encrypting megolm event: %v", err)
+ }
+ if megolmOutSession.MessageCount != 1 {
+ t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount)
+ }
encryptedEvt := &event.Event{
Content: event.Content{Parsed: encryptedEvtContent},
@@ -140,12 +155,22 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
// decrypt event on receiving machine and confirm
decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt)
- require.NoError(t, err, "Error decrypting megolm event")
- assert.Equal(t, event.EventMessage, decryptedEvt.Type)
- assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"])
+ if err != nil {
+ t.Errorf("Error decrypting megolm event: %v", err)
+ }
+ if decryptedEvt.Type != event.EventMessage {
+ t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type)
+ }
+ if decryptedEvt.Content.Raw["hello"] != "world" {
+ t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw)
+ }
machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
- assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message")
+ if megolmOutSession.Expired() {
+ t.Error("Megolm outbound session expired before 3rd message")
+ }
machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent)
- assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message")
+ if !megolmOutSession.Expired() {
+ t.Error("Megolm outbound session not expired after 3rd message")
+ }
}
diff --git a/crypto/olm/account.go b/crypto/olm/account.go
index 2ec5dd70..68393e8a 100644
--- a/crypto/olm/account.go
+++ b/crypto/olm/account.go
@@ -87,8 +87,6 @@ type Account interface {
RemoveOneTimeKeys(s Session) error
}
-var Driver = "none"
-
var InitBlankAccount func() Account
var InitNewAccount func() (Account, error)
var InitNewAccountFromPickled func(pickled, key []byte) (Account, error)
diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go
index 9e522b2a..957d7928 100644
--- a/crypto/olm/errors.go
+++ b/crypto/olm/errors.go
@@ -10,67 +10,50 @@ import "errors"
// Those are the most common used errors
var (
- ErrBadSignature = errors.New("bad signature")
- ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)")
- ErrBadMessageFormat = errors.New("the message couldn't be decoded")
- ErrBadVerification = errors.New("bad verification")
- ErrWrongProtocolVersion = errors.New("wrong protocol version")
- ErrEmptyInput = errors.New("empty input")
- ErrNoKeyProvided = errors.New("no key provided")
- ErrBadMessageKeyID = errors.New("the message references an unknown key ID")
- ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
- ErrMsgIndexTooHigh = errors.New("message index too high")
- ErrProtocolViolation = errors.New("not protocol message order")
- ErrMessageKeyNotFound = errors.New("message key not found")
- ErrChainTooHigh = errors.New("chain index too high")
- ErrBadInput = errors.New("bad input")
- ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version")
- ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version")
- ErrInputToSmall = errors.New("input too small (truncated?)")
+ ErrBadSignature = errors.New("bad signature")
+ ErrBadMAC = errors.New("bad mac")
+ ErrBadMessageFormat = errors.New("bad message format")
+ ErrBadVerification = errors.New("bad verification")
+ ErrWrongProtocolVersion = errors.New("wrong protocol version")
+ ErrEmptyInput = errors.New("empty input")
+ ErrNoKeyProvided = errors.New("no key")
+ ErrBadMessageKeyID = errors.New("bad message key id")
+ ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
+ ErrMsgIndexTooHigh = errors.New("message index too high")
+ ErrProtocolViolation = errors.New("not protocol message order")
+ ErrMessageKeyNotFound = errors.New("message key not found")
+ ErrChainTooHigh = errors.New("chain index too high")
+ ErrBadInput = errors.New("bad input")
+ ErrBadVersion = errors.New("wrong version")
+ ErrWrongPickleVersion = errors.New("wrong pickle version")
+ ErrInputToSmall = errors.New("input too small (truncated?)")
+ ErrOverflow = errors.New("overflow")
)
// Error codes from go-olm
var (
- ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
- ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
+ EmptyInput = errors.New("empty input")
+ NoKeyProvided = errors.New("no pickle key provided")
+ NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
+ SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
+ InputNotJSONString = errors.New("input doesn't look like a JSON string")
)
// Error codes from olm code
var (
- ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid")
-
- ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied")
- ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small")
- ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid")
- ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded")
- ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
- ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
-)
-
-// Deprecated: use variables prefixed with Err
-var (
- EmptyInput = ErrEmptyInput
- BadSignature = ErrBadSignature
- InvalidBase64 = ErrLibolmInvalidBase64
- BadMessageKeyID = ErrBadMessageKeyID
- BadMessageFormat = ErrBadMessageFormat
- BadMessageVersion = ErrWrongProtocolVersion
- BadMessageMAC = ErrBadMAC
- UnknownPickleVersion = ErrUnknownOlmPickleVersion
- NotEnoughRandom = ErrLibolmNotEnoughRandom
- OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall
- BadAccountKey = ErrLibolmBadAccountKey
- CorruptedPickle = ErrLibolmCorruptedPickle
- BadSessionKey = ErrLibolmBadSessionKey
- UnknownMessageIndex = ErrUnknownMessageIndex
- BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle
- InputBufferTooSmall = ErrInputToSmall
- NoKeyProvided = ErrNoKeyProvided
-
- NotEnoughGoRandom = ErrNotEnoughGoRandom
- InputNotJSONString = ErrInputNotJSONString
-
- ErrBadVersion = ErrUnknownJSONPickleVersion
- ErrWrongPickleVersion = ErrUnknownJSONPickleVersion
- ErrRatchetNotAvailable = ErrUnknownMessageIndex
+ NotEnoughRandom = errors.New("not enough entropy was supplied")
+ OutputBufferTooSmall = errors.New("supplied output buffer is too small")
+ BadMessageVersion = errors.New("the message version is unsupported")
+ BadMessageFormat = errors.New("the message couldn't be decoded")
+ BadMessageMAC = errors.New("the message couldn't be decrypted")
+ BadMessageKeyID = errors.New("the message references an unknown key ID")
+ InvalidBase64 = errors.New("the input base64 was invalid")
+ BadAccountKey = errors.New("the supplied account key is invalid")
+ UnknownPickleVersion = errors.New("the pickled object is too new")
+ CorruptedPickle = errors.New("the pickled object couldn't be decoded")
+ BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
+ UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
+ BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
+ BadSignature = errors.New("received message had a bad signature")
+ InputBufferTooSmall = errors.New("the input data was too small to be valid")
)
diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go
index 6b5b65fd..f5cecafc 100644
--- a/crypto/registergoolm.go
+++ b/crypto/registergoolm.go
@@ -2,10 +2,4 @@
package crypto
-import (
- "maunium.net/go/mautrix/crypto/goolm"
-)
-
-func init() {
- goolm.Register()
-}
+import _ "maunium.net/go/mautrix/crypto/goolm"
diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go
index ef78b6b5..ab388a5c 100644
--- a/crypto/registerlibolm.go
+++ b/crypto/registerlibolm.go
@@ -2,8 +2,4 @@
package crypto
-import "maunium.net/go/mautrix/crypto/libolm"
-
-func init() {
- libolm.Register()
-}
+import _ "maunium.net/go/mautrix/crypto/libolm"
diff --git a/crypto/sessions.go b/crypto/sessions.go
index ccc7b784..aecb0416 100644
--- a/crypto/sessions.go
+++ b/crypto/sessions.go
@@ -18,14 +18,8 @@ import (
)
var (
- ErrSessionNotShared = errors.New("session has not been shared")
- ErrSessionExpired = errors.New("session has expired")
-)
-
-// Deprecated: use variables prefixed with Err
-var (
- SessionNotShared = ErrSessionNotShared
- SessionExpired = ErrSessionExpired
+ SessionNotShared = errors.New("session has not been shared")
+ SessionExpired = errors.New("session has expired")
)
// OlmSessionList is a list of OlmSessions.
@@ -117,7 +111,6 @@ type InboundGroupSession struct {
MaxMessages int
IsScheduled bool
KeyBackupVersion id.KeyBackupVersion
- KeySource id.KeySource
id id.SessionID
}
@@ -137,7 +130,6 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: isScheduled,
- KeySource: id.KeySourceDirect,
}, nil
}
@@ -171,7 +163,7 @@ func (igs *InboundGroupSession) export() (*ExportedSession, error) {
ForwardingChains: igs.ForwardingChains,
RoomID: igs.RoomID,
SenderKey: igs.SenderKey,
- SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey},
+ SenderClaimedKeys: SenderClaimedKeys{},
SessionID: igs.ID(),
SessionKey: string(key),
}, nil
@@ -263,9 +255,9 @@ func (ogs *OutboundGroupSession) Expired() bool {
func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if !ogs.Shared {
- return nil, ErrSessionNotShared
+ return nil, SessionNotShared
} else if ogs.Expired() {
- return nil, ErrSessionExpired
+ return nil, SessionExpired
}
ogs.MessageCount++
ogs.LastEncryptedTime = time.Now()
diff --git a/crypto/sql_store.go b/crypto/sql_store.go
index 138cc557..b0625763 100644
--- a/crypto/sql_store.go
+++ b/crypto/sql_store.go
@@ -251,9 +251,8 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender
}
// GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key.
-// This will exclude sessions that have never been used to encrypt or decrypt a message.
func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) {
- err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (last_encrypted <> created_at OR last_decrypted <> created_at) ORDER BY created_at DESC LIMIT 1",
+ err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY created_at DESC LIMIT 1",
key, store.AccountID).Scan(&createdAt)
if errors.Is(err, sql.ErrNoRows) {
err = nil
@@ -346,23 +345,22 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou
Int("max_messages", session.MaxMessages).
Bool("is_scheduled", session.IsScheduled).
Stringer("key_backup_version", session.KeyBackupVersion).
- Stringer("key_source", session.KeySource).
Msg("Upserting megolm inbound group session")
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
- ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id
- ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
- key_backup_version=excluded.key_backup_version, key_source=excluded.key_source
+ key_backup_version=excluded.key_backup_version
`,
session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages),
- session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID,
+ session.IsScheduled, session.KeyBackupVersion, store.AccountID,
)
return err
}
@@ -375,13 +373,12 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
var version id.KeyBackupVersion
- var keySource id.KeySource
err := store.DB.QueryRow(ctx, `
- SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
+ SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3`,
roomID, sessionID, store.AccountID,
- ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
+ ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
@@ -412,7 +409,6 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
- KeySource: keySource,
}, nil
}
@@ -537,8 +533,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
var version id.KeyBackupVersion
- var keySource id.KeySource
- err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
+ err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
if err != nil {
return nil, err
}
@@ -558,13 +553,12 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
- KeySource: keySource,
}, nil
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
@@ -573,7 +567,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
store.AccountID,
)
@@ -582,7 +576,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
store.AccountID, version,
)
@@ -669,20 +663,6 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
- if eventID == "" && timestamp == 0 {
- var notOK bool
- const validateEmptyQuery = `
- SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3)
- `
- err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK)
- if notOK {
- zerolog.Ctx(ctx).Debug().
- Uint("message_index", index).
- Msg("Rejecting event without event ID and timestamp due to already knowing them")
- }
- return !notOK, err
- }
-
const validateQuery = `
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
VALUES ($1, $2, $3, $4, $5)
diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql
index 3709f1e5..00dd1387 100644
--- a/crypto/sql_store_upgrade/00-latest-revision.sql
+++ b/crypto/sql_store_upgrade/00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v19 (compatible with v15+): Latest revision
+-- v0 -> v17 (compatible with v15+): Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@@ -71,11 +71,8 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
key_backup_version TEXT NOT NULL DEFAULT '',
- key_source TEXT NOT NULL DEFAULT '',
PRIMARY KEY (account_id, session_id)
);
--- Useful index to find keys that need backing up
-CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL;
CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
account_id TEXT,
diff --git a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql
deleted file mode 100644
index da26da0f..00000000
--- a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v18 (compatible with v15+): Add an index to the megolm_inbound_session table to make finding sessions to backup faster
-CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL;
diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql
deleted file mode 100644
index f624222f..00000000
--- a/crypto/sql_store_upgrade/19-megolm-session-source.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v19 (compatible with v15+): Store megolm session source
-ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT '';
diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go
index 8691d032..e30925d9 100644
--- a/crypto/ssss/client.go
+++ b/crypto/ssss/client.go
@@ -95,22 +95,6 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even
return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted})
}
-// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it,
-// alongside the unencrypted metadata, on the server.
-func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error {
- if len(keys) == 0 {
- return ErrNoKeyGiven
- }
- encrypted := make(map[string]EncryptedKeyData, len(keys))
- for _, key := range keys {
- encrypted[key.ID] = key.Encrypt(eventType.Type, data)
- }
- return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{
- Encrypted: encrypted,
- Metadata: metadata,
- })
-}
-
// GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server.
func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) {
key, err = NewKey(passphrase)
diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go
index 78ebd8f3..aa22360a 100644
--- a/crypto/ssss/key.go
+++ b/crypto/ssss/key.go
@@ -7,8 +7,6 @@
package ssss
import (
- "crypto/hmac"
- "crypto/sha256"
"encoding/base64"
"fmt"
"strings"
@@ -59,12 +57,12 @@ func NewKey(passphrase string) (*Key, error) {
// We store a certain hash in the key metadata so that clients can check if the user entered the correct key.
ivBytes := random.Bytes(utils.AESCTRIVLength)
keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes)
- macBytes, err := keyData.calculateHash(ssssKey)
+ var err error
+ keyData.MAC, err = keyData.calculateHash(ssssKey)
if err != nil {
// This should never happen because we just generated the IV and key.
return nil, fmt.Errorf("failed to calculate hash: %w", err)
}
- keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes)
return &Key{
Key: ssssKey,
@@ -110,18 +108,12 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error)
return nil, err
}
- mac, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.MAC, "="))
- if err != nil {
- return nil, err
- }
-
// derive the AES and HMAC keys for the requested event type using the SSSS key
aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType)
// compare the stored MAC with the one we calculated from the ciphertext
- h := hmac.New(sha256.New, hmacKey[:])
- h.Write(payload)
- if !hmac.Equal(h.Sum(nil), mac) {
+ calcMac := utils.HMACSHA256B64(payload, hmacKey)
+ if strings.TrimRight(data.MAC, "=") != calcMac {
return nil, ErrKeyDataMACMismatch
}
diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go
index 34775fa7..474c85d8 100644
--- a/crypto/ssss/meta.go
+++ b/crypto/ssss/meta.go
@@ -7,10 +7,7 @@
package ssss
import (
- "crypto/hmac"
- "crypto/sha256"
"encoding/base64"
- "errors"
"fmt"
"strings"
@@ -36,9 +33,7 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error)
ssssKey, err := kd.Passphrase.GetKey(passphrase)
if err != nil {
return nil, err
- }
- err = kd.verifyKey(ssssKey)
- if err != nil && !errors.Is(err, ErrUnverifiableKey) {
+ } else if err = kd.verifyKey(ssssKey); err != nil {
return nil, err
}
@@ -54,9 +49,7 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error
ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey)
if ssssKey == nil {
return nil, ErrInvalidRecoveryKey
- }
- err := kd.verifyKey(ssssKey)
- if err != nil && !errors.Is(err, ErrUnverifiableKey) {
+ } else if err := kd.verifyKey(ssssKey); err != nil {
return nil, err
}
@@ -64,28 +57,20 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error
ID: keyID,
Key: ssssKey,
Metadata: kd,
- }, err
+ }, nil
}
func (kd *KeyMetadata) verifyKey(key []byte) error {
- if kd.MAC == "" || kd.IV == "" {
- return ErrUnverifiableKey
- }
unpaddedMAC := strings.TrimRight(kd.MAC, "=")
expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength)
if len(unpaddedMAC) != expectedMACLength {
return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength)
}
- expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC)
- if err != nil {
- return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err)
- }
- calculatedMAC, err := kd.calculateHash(key)
+ hash, err := kd.calculateHash(key)
if err != nil {
return err
}
- // This doesn't really need to be constant time since it's fully local, but might as well be.
- if !hmac.Equal(expectedMAC, calculatedMAC) {
+ if unpaddedMAC != hash {
return ErrIncorrectSSSSKey
}
return nil
@@ -98,26 +83,23 @@ func (kd *KeyMetadata) VerifyKey(key []byte) bool {
// calculateHash calculates the hash used for checking if the key is entered correctly as described
// in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2
-func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) {
+func (kd *KeyMetadata) calculateHash(key []byte) (string, error) {
aesKey, hmacKey := utils.DeriveKeysSHA256(key, "")
unpaddedIV := strings.TrimRight(kd.IV, "=")
expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength)
- if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 {
- return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength)
+ if len(unpaddedIV) != expectedIVLength {
+ return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength)
}
- rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV)
- if err != nil {
- return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err)
- }
- // TODO log a warning for non-16 byte IVs?
- // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used.
- ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength])
- zeroes := make([]byte, utils.AESCTRKeyLength)
- encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes)
- h := hmac.New(sha256.New, hmacKey[:])
- h.Write(encryptedZeroes)
- return h.Sum(nil), nil
+ var ivBytes [utils.AESCTRIVLength]byte
+ _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV))
+ if err != nil {
+ return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err)
+ }
+
+ cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes)
+
+ return utils.HMACSHA256B64(cipher, hmacKey), nil
}
// PassphraseMetadata represents server-side metadata about a SSSS key passphrase.
diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go
index d59809c7..4f2ff378 100644
--- a/crypto/ssss/meta_test.go
+++ b/crypto/ssss/meta_test.go
@@ -8,10 +8,10 @@ package ssss_test
import (
"encoding/json"
+ "errors"
"testing"
"github.com/stretchr/testify/assert"
- "go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/crypto/ssss"
)
@@ -41,24 +41,10 @@ const key2Meta = `
}
`
-const key2MetaUnverified = `
-{
- "algorithm": "m.secret_storage.v1.aes-hmac-sha2"
-}
-`
-
-const key2MetaLongIV = `
-{
- "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
- "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=",
- "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
-}
-`
-
const key2MetaBrokenIV = `
{
"algorithm": "m.secret_storage.v1.aes-hmac-sha2",
- "iv": "MeowMeowMeow",
+ "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow",
"mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
}
`
@@ -84,11 +70,23 @@ func getKeyMeta(meta string) *ssss.KeyMetadata {
}
func getKey1() *ssss.Key {
- return exerrors.Must(getKeyMeta(key1Meta).VerifyRecoveryKey(key1ID, key1RecoveryKey))
+ km := getKeyMeta(key1Meta)
+ key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey)
+ if err != nil {
+ panic(err)
+ }
+ key.ID = key1ID
+ return key
}
func getKey2() *ssss.Key {
- return exerrors.Must(getKeyMeta(key2Meta).VerifyRecoveryKey(key2ID, key2RecoveryKey))
+ km := getKeyMeta(key2Meta)
+ key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
+ if err != nil {
+ panic(err)
+ }
+ key.ID = key2ID
+ return key
}
func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) {
@@ -107,33 +105,17 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) {
assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
}
-func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) {
- km := getKeyMeta(key2MetaLongIV)
- key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.NoError(t, err)
- assert.NotNil(t, key)
- assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
-}
-
-func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) {
- km := getKeyMeta(key2MetaUnverified)
- key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.ErrorIs(t, err, ssss.ErrUnverifiableKey)
- assert.NotNil(t, key)
- assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
-}
-
func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyRecoveryKey(key1ID, "foo")
- assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey)
+ assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey)
+ assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err)
assert.Nil(t, key)
}
@@ -148,27 +130,27 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) {
func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple")
- assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey)
+ assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) {
km := getKeyMeta(key2Meta)
key, err := km.VerifyPassphrase(key2ID, "hmm")
- assert.ErrorIs(t, err, ssss.ErrNoPassphrase)
+ assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) {
km := getKeyMeta(key2MetaBrokenIV)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata)
+ assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) {
km := getKeyMeta(key2MetaBrokenMAC)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata)
+ assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err)
assert.Nil(t, key)
}
diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go
index b7465d3e..345393b0 100644
--- a/crypto/ssss/types.go
+++ b/crypto/ssss/types.go
@@ -26,8 +26,7 @@ var (
ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm")
ErrIncorrectSSSSKey = errors.New("incorrect SSSS key")
ErrInvalidRecoveryKey = errors.New("invalid recovery key")
- ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata")
- ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata")
+ ErrCorruptedKeyMetadata = errors.New("corrupted key metadata")
)
// Algorithm is the identifier for an SSSS encryption algorithm.
@@ -58,7 +57,6 @@ type EncryptedKeyData struct {
type EncryptedAccountDataEventContent struct {
Encrypted map[string]EncryptedKeyData `json:"encrypted"`
- Metadata map[string]any `json:"com.beeper.metadata,omitzero"`
}
func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) {
diff --git a/crypto/store.go b/crypto/store.go
index 7620cf35..8b7c0a96 100644
--- a/crypto/store.go
+++ b/crypto/store.go
@@ -525,9 +525,6 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send
}
val, ok := gs.MessageIndices[key]
if !ok {
- if eventID == "" && timestamp == 0 {
- return true, nil
- }
gs.MessageIndices[key] = messageIndexValue{
EventID: eventID,
Timestamp: timestamp,
diff --git a/crypto/store_test.go b/crypto/store_test.go
index 7a47243e..a7c4d75a 100644
--- a/crypto/store_test.go
+++ b/crypto/store_test.go
@@ -13,7 +13,6 @@ import (
"testing"
_ "github.com/mattn/go-sqlite3"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mau.fi/util/dbutil"
@@ -30,14 +29,22 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4
func getCryptoStores(t *testing.T) map[string]Store {
rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000")
- require.NoError(t, err, "Error opening raw database")
+ if err != nil {
+ t.Fatalf("Error opening db: %v", err)
+ }
db, err := dbutil.NewWithDB(rawDB, "sqlite3")
- require.NoError(t, err, "Error creating database wrapper")
+ if err != nil {
+ t.Fatalf("Error opening db: %v", err)
+ }
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
- err = sqlStore.DB.Upgrade(context.TODO())
- require.NoError(t, err, "Error upgrading database")
+ if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
+ t.Fatalf("Error creating tables: %v", err)
+ }
gobStore := NewMemoryStore(nil)
+ if err != nil {
+ t.Fatalf("Error creating Gob store: %v", err)
+ }
return map[string]Store{
"sql": sqlStore,
@@ -49,10 +56,9 @@ func TestPutNextBatch(t *testing.T) {
stores := getCryptoStores(t)
store := stores["sql"].(*SQLCryptoStore)
store.PutNextBatch(context.Background(), "batch1")
-
- batch, err := store.GetNextBatch(context.Background())
- require.NoError(t, err, "Error retrieving next batch")
- assert.Equal(t, "batch1", batch)
+ if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" {
+ t.Errorf("Expected batch1, got %v", batch)
+ }
}
func TestPutAccount(t *testing.T) {
@@ -62,9 +68,15 @@ func TestPutAccount(t *testing.T) {
acc := NewOlmAccount()
store.PutAccount(context.TODO(), acc)
retrieved, err := store.GetAccount(context.TODO())
- require.NoError(t, err, "Error retrieving account")
- assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match")
- assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match")
+ if err != nil {
+ t.Fatalf("Error retrieving account: %v", err)
+ }
+ if acc.IdentityKey() != retrieved.IdentityKey() {
+ t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey())
+ }
+ if acc.SigningKey() != retrieved.SigningKey() {
+ t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey())
+ }
})
}
}
@@ -74,36 +86,18 @@ func TestValidateMessageIndex(t *testing.T) {
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
acc := NewOlmAccount()
-
- // Validating without event ID and timestamp before we have them should work
- ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0)
- require.NoError(t, err, "Error validating message index")
- assert.True(t, ok, "First message validation should be valid")
-
- // First message should validate successfully
- ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000)
- require.NoError(t, err, "Error validating message index")
- assert.True(t, ok, "First message validation should be valid")
-
- // Edit the timestamp and ensure validate fails
- ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001)
- require.NoError(t, err, "Error validating message index after timestamp change")
- assert.False(t, ok, "First message validation should fail after timestamp change")
-
- // Edit the event ID and ensure validate fails
- ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000)
- require.NoError(t, err, "Error validating message index after event ID change")
- assert.False(t, ok, "First message validation should fail after event ID change")
-
- // Validate again with the original parameters and ensure that it still passes
- ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000)
- require.NoError(t, err, "Error validating message index")
- assert.True(t, ok, "First message validation should be valid")
-
- // Validating without event ID and timestamp must fail if we already know them
- ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0)
- require.NoError(t, err, "Error validating message index")
- assert.False(t, ok, "First message validation should be invalid")
+ if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok {
+ t.Error("First message not validated successfully")
+ }
+ if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok {
+ t.Error("First message validated successfully after changing timestamp")
+ }
+ if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok {
+ t.Error("First message validated successfully after changing event ID")
+ }
+ if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok {
+ t.Error("First message not validated successfully for a second time")
+ }
})
}
}
@@ -112,26 +106,43 @@ func TestStoreOlmSession(t *testing.T) {
stores := getCryptoStores(t)
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
- require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it")
-
+ if store.HasSession(context.TODO(), olmSessID) {
+ t.Error("Found Olm session before inserting it")
+ }
olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test"))
- require.NoError(t, err, "Error creating internal Olm session")
+ if err != nil {
+ t.Fatalf("Error creating internal Olm session: %v", err)
+ }
olmSess := OlmSession{
id: olmSessID,
Internal: olmInternal,
}
err = store.AddSession(context.TODO(), olmSessID, &olmSess)
- require.NoError(t, err, "Error storing Olm session")
- assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it")
+ if err != nil {
+ t.Errorf("Error storing Olm session: %v", err)
+ }
+ if !store.HasSession(context.TODO(), olmSessID) {
+ t.Error("Not found Olm session after inserting it")
+ }
retrieved, err := store.GetLatestSession(context.TODO(), olmSessID)
- require.NoError(t, err, "Error retrieving Olm session")
- assert.EqualValues(t, olmSessID, retrieved.ID())
+ if err != nil {
+ t.Errorf("Failed retrieving Olm session: %v", err)
+ }
+
+ if retrieved.ID() != olmSessID {
+ t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID())
+ }
pickled, err := retrieved.Internal.Pickle([]byte("test"))
- require.NoError(t, err, "Error pickling Olm session")
- assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original")
+ if err != nil {
+ t.Fatalf("Error pickling Olm session: %v", err)
+ }
+
+ if string(pickled) != olmPickled {
+ t.Error("Pickled Olm session does not match original")
+ }
})
}
}
@@ -143,7 +154,9 @@ func TestStoreMegolmSession(t *testing.T) {
acc := NewOlmAccount()
internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test"))
- require.NoError(t, err, "Error creating internal inbound group session")
+ if err != nil {
+ t.Fatalf("Error creating internal inbound group session: %v", err)
+ }
igs := &InboundGroupSession{
Internal: internal,
@@ -153,14 +166,20 @@ func TestStoreMegolmSession(t *testing.T) {
}
err = store.PutGroupSession(context.TODO(), igs)
- require.NoError(t, err, "Error storing inbound group session")
+ if err != nil {
+ t.Errorf("Error storing inbound group session: %v", err)
+ }
retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID())
- require.NoError(t, err, "Error retrieving inbound group session")
+ if err != nil {
+ t.Errorf("Error retrieving inbound group session: %v", err)
+ }
- pickled, err := retrieved.Internal.Pickle([]byte("test"))
- require.NoError(t, err, "Error pickling inbound group session")
- assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original")
+ if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil {
+ t.Fatalf("Error pickling inbound group session: %v", err)
+ } else if string(pickled) != groupSession {
+ t.Error("Pickled inbound group session does not match original")
+ }
})
}
}
@@ -170,24 +189,40 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
sess, err := store.GetOutboundGroupSession(context.TODO(), "room1")
- require.NoError(t, err, "Error retrieving outbound session")
- require.Nil(t, sess, "Got outbound session before inserting")
+ if sess != nil {
+ t.Error("Got outbound session before inserting")
+ }
+ if err != nil {
+ t.Errorf("Error retrieving outbound session: %v", err)
+ }
outbound, err := NewOutboundGroupSession("room1", nil)
require.NoError(t, err)
err = store.AddOutboundGroupSession(context.TODO(), outbound)
- require.NoError(t, err, "Error inserting outbound session")
+ if err != nil {
+ t.Errorf("Error inserting outbound session: %v", err)
+ }
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
- require.NoError(t, err, "Error retrieving outbound session")
- assert.NotNil(t, sess, "Did not get outbound session after inserting")
+ if sess == nil {
+ t.Error("Did not get outbound session after inserting")
+ }
+ if err != nil {
+ t.Errorf("Error retrieving outbound session: %v", err)
+ }
err = store.RemoveOutboundGroupSession(context.TODO(), "room1")
- require.NoError(t, err, "Error deleting outbound session")
+ if err != nil {
+ t.Errorf("Error deleting outbound session: %v", err)
+ }
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
- require.NoError(t, err, "Error retrieving outbound session after deletion")
- assert.Nil(t, sess, "Got outbound session after deleting")
+ if sess != nil {
+ t.Error("Got outbound session after deleting")
+ }
+ if err != nil {
+ t.Errorf("Error retrieving outbound session: %v", err)
+ }
})
}
}
@@ -209,41 +244,58 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) {
t.Run(storeName, func(t *testing.T) {
device := resetDevice()
err := store.PutDevice(context.TODO(), "user1", device)
- require.NoError(t, err, "Error storing device")
+ if err != nil {
+ t.Errorf("Error storing devices: %v", err)
+ }
shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- require.NoError(t, err, "Error checking if outbound group session is shared")
- assert.False(t, shared, "Outbound group session should not be shared initially")
+ if err != nil {
+ t.Errorf("Error checking if outbound group session is shared: %v", err)
+ } else if shared {
+ t.Errorf("Outbound group session shared when it shouldn't")
+ }
err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- require.NoError(t, err, "Error marking outbound group session as shared")
+ if err != nil {
+ t.Errorf("Error marking outbound group session as shared: %v", err)
+ }
shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- require.NoError(t, err, "Error checking if outbound group session is shared")
- assert.True(t, shared, "Outbound group session should be shared after marking it as such")
+ if err != nil {
+ t.Errorf("Error checking if outbound group session is shared: %v", err)
+ } else if !shared {
+ t.Errorf("Outbound group session not shared when it should")
+ }
device = resetDevice()
err = store.PutDevice(context.TODO(), "user1", device)
- require.NoError(t, err, "Error storing device after resetting")
+ if err != nil {
+ t.Errorf("Error storing devices: %v", err)
+ }
shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1")
- require.NoError(t, err, "Error checking if outbound group session is shared")
- assert.False(t, shared, "Outbound group session should not be shared after resetting device")
+ if err != nil {
+ t.Errorf("Error checking if outbound group session is shared: %v", err)
+ } else if shared {
+ t.Errorf("Outbound group session shared when it shouldn't")
+ }
})
}
}
func TestStoreDevices(t *testing.T) {
- devicesToCreate := 17
stores := getCryptoStores(t)
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
outdated, err := store.GetOutdatedTrackedUsers(context.TODO())
- require.NoError(t, err, "Error filtering tracked users")
- assert.Empty(t, outdated, "Expected no outdated tracked users initially")
-
+ if err != nil {
+ t.Errorf("Error filtering tracked users: %v", err)
+ }
+ if len(outdated) > 0 {
+ t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
+ }
deviceMap := make(map[id.DeviceID]*id.Device)
- for i := 0; i < devicesToCreate; i++ {
+ for i := 0; i < 17; i++ {
iStr := strconv.Itoa(i)
acc := NewOlmAccount()
deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{
@@ -254,33 +306,59 @@ func TestStoreDevices(t *testing.T) {
}
}
err = store.PutDevices(context.TODO(), "user1", deviceMap)
- require.NoError(t, err, "Error storing devices")
+ if err != nil {
+ t.Errorf("Error storing devices: %v", err)
+ }
devs, err := store.GetDevices(context.TODO(), "user1")
- require.NoError(t, err, "Error getting devices")
- assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate)
- assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices")
+ if err != nil {
+ t.Errorf("Error getting devices: %v", err)
+ }
+ if len(devs) != 17 {
+ t.Errorf("Stored 17 devices, got back %v", len(devs))
+ }
+ if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey {
+ t.Errorf("First device identity key does not match")
+ }
+ if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey {
+ t.Errorf("Last device identity key does not match")
+ }
filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"})
- require.NoError(t, err, "Error filtering tracked users")
- assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter")
+ if err != nil {
+ t.Errorf("Error filtering tracked users: %v", err)
+ } else if len(filtered) != 1 || filtered[0] != "user1" {
+ t.Errorf("Expected to get 'user1' from filter, got %v", filtered)
+ }
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
- require.NoError(t, err, "Error filtering tracked users")
- assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage")
-
+ if err != nil {
+ t.Errorf("Error filtering tracked users: %v", err)
+ }
+ if len(outdated) > 0 {
+ t.Errorf("Got %d outdated tracked users when expected none", len(outdated))
+ }
err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"})
- require.NoError(t, err, "Error marking tracked users outdated")
-
+ if err != nil {
+ t.Errorf("Error marking tracked users outdated: %v", err)
+ }
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
- require.NoError(t, err, "Error filtering tracked users")
- assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated")
-
+ if err != nil {
+ t.Errorf("Error filtering tracked users: %v", err)
+ }
+ if len(outdated) != 1 || outdated[0] != id.UserID("user1") {
+ t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated)
+ }
err = store.PutDevices(context.TODO(), "user1", deviceMap)
- require.NoError(t, err, "Error storing devices again")
-
+ if err != nil {
+ t.Errorf("Error storing devices: %v", err)
+ }
outdated, err = store.GetOutdatedTrackedUsers(context.TODO())
- require.NoError(t, err, "Error filtering tracked users")
- assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices")
+ if err != nil {
+ t.Errorf("Error filtering tracked users: %v", err)
+ }
+ if len(outdated) > 0 {
+ t.Errorf("Got outdated tracked users %v when expected none", outdated)
+ }
})
}
}
@@ -291,11 +369,16 @@ func TestStoreSecrets(t *testing.T) {
t.Run(storeName, func(t *testing.T) {
storedSecret := "trustno1"
err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret)
- require.NoError(t, err, "Error storing secret")
+ if err != nil {
+ t.Errorf("Error storing secret: %v", err)
+ }
secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1)
- require.NoError(t, err, "Error retrieving secret")
- assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret")
+ if err != nil {
+ t.Errorf("Error storing secret: %v", err)
+ } else if secret != storedSecret {
+ t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret)
+ }
})
}
}
diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go
index b12fd9e2..c4f01a68 100644
--- a/crypto/utils/utils_test.go
+++ b/crypto/utils/utils_test.go
@@ -9,9 +9,6 @@ package utils
import (
"encoding/base64"
"testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
)
func TestAES256Ctr(t *testing.T) {
@@ -19,7 +16,9 @@ func TestAES256Ctr(t *testing.T) {
key, iv := GenAttachmentA256CTR()
enc := XorA256CTR([]byte(expected), key, iv)
dec := XorA256CTR(enc, key, iv)
- assert.EqualValues(t, expected, dec, "Decrypted text should match original")
+ if string(dec) != expected {
+ t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec))
+ }
var key2 [AESCTRKeyLength]byte
var iv2 [AESCTRIVLength]byte
@@ -30,7 +29,9 @@ func TestAES256Ctr(t *testing.T) {
iv2[i] = byte(i) + 32
}
dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2)
- assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original")
+ if string(dec2) != expected {
+ t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2))
+ }
}
func TestPBKDF(t *testing.T) {
@@ -41,7 +42,9 @@ func TestPBKDF(t *testing.T) {
key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256)
expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E="
keyB64 := base64.StdEncoding.EncodeToString([]byte(key))
- assert.Equal(t, expected, keyB64)
+ if keyB64 != expected {
+ t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64)
+ }
}
func TestDecodeSSSSKey(t *testing.T) {
@@ -50,10 +53,13 @@ func TestDecodeSSSSKey(t *testing.T) {
expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw="
decodedB64 := base64.StdEncoding.EncodeToString(decoded[:])
- assert.Equal(t, expected, decodedB64)
+ if expected != decodedB64 {
+ t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64)
+ }
- encoded := EncodeBase58RecoveryKey(decoded)
- assert.Equal(t, recoveryKey, encoded)
+ if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey {
+ t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded)
+ }
}
func TestKeyDerivationAndHMAC(t *testing.T) {
@@ -63,11 +69,15 @@ func TestKeyDerivationAndHMAC(t *testing.T) {
aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master")
ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=")
- require.NoError(t, err)
+ if err != nil {
+ t.Error(err)
+ }
calcMac := HMACSHA256B64(ciphertextBytes, hmacKey)
expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E"
- assert.Equal(t, expectedMac, calcMac)
+ if calcMac != expectedMac {
+ t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac)
+ }
var ivBytes [AESCTRIVLength]byte
decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==")
@@ -75,5 +85,7 @@ func TestKeyDerivationAndHMAC(t *testing.T) {
decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes))
expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s="
- assert.Equal(t, expectedDec, decrypted)
+ if expectedDec != decrypted {
+ t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted)
+ }
}
diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go
new file mode 100644
index 00000000..b6bf3d2c
--- /dev/null
+++ b/crypto/verificationhelper/mockserver_test.go
@@ -0,0 +1,255 @@
+// Copyright (c) 2024 Sumner Evans
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package verificationhelper_test
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/gorilla/mux"
+ "github.com/rs/zerolog/log" // zerolog-allow-global-log
+ "github.com/stretchr/testify/require"
+ "go.mau.fi/util/random"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/crypto"
+ "maunium.net/go/mautrix/crypto/cryptohelper"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow
+// testing of the interactive verification process.
+type mockServer struct {
+ *httptest.Server
+
+ AccessTokenToUserID map[string]id.UserID
+ DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
+ AccountData map[id.UserID]map[event.Type]json.RawMessage
+ DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
+ MasterKeys map[id.UserID]mautrix.CrossSigningKeys
+ SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+ UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
+}
+
+func DecodeVarsMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ vars := mux.Vars(r)
+ var err error
+ for k, v := range vars {
+ vars[k], err = url.PathUnescape(v)
+ if err != nil {
+ panic(err)
+ }
+ }
+ next.ServeHTTP(w, r)
+ })
+}
+
+func createMockServer(t *testing.T) *mockServer {
+ t.Helper()
+
+ server := mockServer{
+ AccessTokenToUserID: map[string]id.UserID{},
+ DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
+ AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ }
+
+ router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath()
+ router.Use(DecodeVarsMiddleware)
+ router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost)
+ router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost)
+ router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut)
+ router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut)
+ router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost)
+ router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost)
+ router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost)
+
+ server.Server = httptest.NewServer(router)
+ return &server
+}
+
+func (ms *mockServer) getUserID(r *http.Request) id.UserID {
+ authHeader := r.Header.Get("Authorization")
+ authHeader = strings.TrimPrefix(authHeader, "Bearer ")
+ userID, ok := ms.AccessTokenToUserID[authHeader]
+ if !ok {
+ panic("no user ID found for access token " + authHeader)
+ }
+ return userID
+}
+
+func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
+ w.Write([]byte("{}"))
+}
+
+func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) {
+ var loginReq mautrix.ReqLogin
+ json.NewDecoder(r.Body).Decode(&loginReq)
+
+ deviceID := loginReq.DeviceID
+ if deviceID == "" {
+ deviceID = id.DeviceID(random.String(10))
+ }
+
+ accessToken := random.String(30)
+ userID := id.UserID(loginReq.Identifier.User)
+ s.AccessTokenToUserID[accessToken] = userID
+
+ json.NewEncoder(w).Encode(&mautrix.RespLogin{
+ AccessToken: accessToken,
+ DeviceID: deviceID,
+ UserID: userID,
+ })
+}
+
+func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
+ vars := mux.Vars(r)
+ var req mautrix.ReqSendToDevice
+ json.NewDecoder(r.Body).Decode(&req)
+ evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType}
+
+ for user, devices := range req.Messages {
+ for device, content := range devices {
+ if _, ok := s.DeviceInbox[user]; !ok {
+ s.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
+ }
+ content.ParseRaw(evtType)
+ s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{
+ Sender: s.getUserID(r),
+ Type: evtType,
+ Content: *content,
+ })
+ }
+ }
+ s.emptyResp(w, r)
+}
+
+func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
+ vars := mux.Vars(r)
+ userID := id.UserID(vars["userID"])
+ eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType}
+
+ jsonData, _ := io.ReadAll(r.Body)
+ if _, ok := s.AccountData[userID]; !ok {
+ s.AccountData[userID] = map[event.Type]json.RawMessage{}
+ }
+ s.AccountData[userID][eventType] = json.RawMessage(jsonData)
+ s.emptyResp(w, r)
+}
+
+func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqQueryKeys
+ json.NewDecoder(r.Body).Decode(&req)
+ resp := mautrix.RespQueryKeys{
+ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
+ DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
+ }
+ for user := range req.DeviceKeys {
+ resp.MasterKeys[user] = s.MasterKeys[user]
+ resp.UserSigningKeys[user] = s.UserSigningKeys[user]
+ resp.SelfSigningKeys[user] = s.SelfSigningKeys[user]
+ resp.DeviceKeys[user] = s.DeviceKeys[user]
+ }
+ json.NewEncoder(w).Encode(&resp)
+}
+
+func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.ReqUploadKeys
+ json.NewDecoder(r.Body).Decode(&req)
+
+ userID := s.getUserID(r)
+ if _, ok := s.DeviceKeys[userID]; !ok {
+ s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
+ }
+ s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys
+
+ json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{
+ OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50},
+ })
+}
+
+func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
+ var req mautrix.UploadCrossSigningKeysReq
+ json.NewDecoder(r.Body).Decode(&req)
+
+ userID := s.getUserID(r)
+ s.MasterKeys[userID] = req.Master
+ s.SelfSigningKeys[userID] = req.SelfSigning
+ s.UserSigningKeys[userID] = req.UserSigning
+
+ s.emptyResp(w, r)
+}
+
+func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
+ t.Helper()
+ client, err := mautrix.NewClient(ms.URL, "", "")
+ require.NoError(t, err)
+ client.StateStore = mautrix.NewMemoryStateStore()
+
+ _, err = client.Login(ctx, &mautrix.ReqLogin{
+ Type: mautrix.AuthTypePassword,
+ Identifier: mautrix.UserIdentifier{
+ Type: mautrix.IdentifierTypeUser,
+ User: userID.String(),
+ },
+ DeviceID: deviceID,
+ Password: "password",
+ StoreCredentials: true,
+ })
+ require.NoError(t, err)
+
+ cryptoStore := crypto.NewMemoryStore(nil)
+ cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore)
+ require.NoError(t, err)
+ client.Crypto = cryptoHelper
+
+ err = cryptoHelper.Init(ctx)
+ require.NoError(t, err)
+
+ machineLog := log.Logger.With().
+ Stringer("my_user_id", userID).
+ Stringer("my_device_id", deviceID).
+ Logger()
+ cryptoHelper.Machine().Log = &machineLog
+
+ err = cryptoHelper.Machine().ShareKeys(ctx, 50)
+ require.NoError(t, err)
+
+ return client, cryptoStore
+}
+
+func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) {
+ t.Helper()
+
+ for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
+ client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
+ ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
+ }
+}
+
+func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) {
+ err := cryptoStore.PutDevice(ctx, userID, &id.Device{
+ UserID: userID,
+ DeviceID: deviceID,
+ })
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go
index e6392c79..1313a613 100644
--- a/crypto/verificationhelper/sas.go
+++ b/crypto/verificationhelper/sas.go
@@ -695,7 +695,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific
// Verify the MAC for each key
var theirDevice *id.Device
for keyID, mac := range macEvt.MAC {
- log.Info().Stringer("key_id", keyID).Msg("Received MAC for key")
+ log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key")
alg, kID := keyID.Parse()
if alg != id.KeyAlgorithmEd25519 {
diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go
index 0a781c16..9d843ea8 100644
--- a/crypto/verificationhelper/verificationhelper.go
+++ b/crypto/verificationhelper/verificationhelper.go
@@ -848,7 +848,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif
// here, since the start command for scanning and showing QR codes
// should be of type m.reciprocate.v1.
log.Error().Str("method", string(txn.StartEventContent.Method)).Msg("Unsupported verification method in start event")
- vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, "unknown method %s", txn.StartEventContent.Method)
+ vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", txn.StartEventContent.Method))
}
}
diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go
index 5e3f146b..aace2230 100644
--- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go
+++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go
@@ -32,6 +32,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -50,10 +51,10 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, bobUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@@ -82,7 +83,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device detected that its QR code
// was scanned.
@@ -97,7 +98,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = sendingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
} else { // receiving scans QR
// Emulate scanning the QR code shown by the sending device on
// the receiving device.
@@ -120,7 +121,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device detected that its QR code was
// scanned.
@@ -135,7 +136,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = receivingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
}
// Ensure that both devices have marked the verification as done.
diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go
index ea918cd4..937cc414 100644
--- a/crypto/verificationhelper/verificationhelper_qr_self_test.go
+++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go
@@ -36,6 +36,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -61,7 +62,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
if tc.expectedAcceptError != "" {
@@ -71,7 +72,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) {
require.NoError(t, err)
}
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@@ -134,6 +135,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -150,10 +152,10 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID)
require.NotNil(t, receivingShownQRCode)
@@ -182,7 +184,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device detected that its QR code
// was scanned.
@@ -197,7 +199,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = sendingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
} else { // receiving scans QR
// Emulate scanning the QR code shown by the sending device on
// the receiving device.
@@ -220,7 +222,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
// Handle the start and done events on the receiving client and
// confirm the scan.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device detected that its QR code was
// scanned.
@@ -235,7 +237,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) {
doneEvt = receivingInbox[0].Content.AsVerificationDone()
assert.Equal(t, txnID, doneEvt.TransactionID)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
}
// Ensure that both devices have marked the verification as done.
@@ -249,6 +251,7 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -260,10 +263,10 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes()
sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes()
@@ -307,6 +310,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -323,10 +327,10 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes()
sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes()
@@ -344,7 +348,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// Ensure that the receiving device received a cancellation.
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 1)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
cancellation := receivingCallbacks.GetVerificationCancellation(txnID)
require.NotNil(t, cancellation)
assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code)
@@ -358,7 +362,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) {
// Ensure that the sending device received a cancellation.
sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID]
assert.Len(t, sendingInbox, 1)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
cancellation := sendingCallbacks.GetVerificationCancellation(txnID)
require.NotNil(t, cancellation)
assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code)
diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go
index 283eca84..5747ac34 100644
--- a/crypto/verificationhelper/verificationhelper_sas_test.go
+++ b/crypto/verificationhelper/verificationhelper_sas_test.go
@@ -36,6 +36,7 @@ func TestVerification_SAS(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -59,10 +60,10 @@ func TestVerification_SAS(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Test that the start event is correct
var startEvt *event.VerificationStartEventContent
@@ -101,7 +102,7 @@ func TestVerification_SAS(t *testing.T) {
if tc.sendingStartsSAS {
// Process the verification start event on the receiving
// device.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Receiving device sent the accept event to the sending device
sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID]
@@ -109,7 +110,7 @@ func TestVerification_SAS(t *testing.T) {
acceptEvt = sendingInbox[0].Content.AsVerificationAccept()
} else {
// Process the verification start event on the sending device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Sending device sent the accept event to the receiving device
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
@@ -128,7 +129,7 @@ func TestVerification_SAS(t *testing.T) {
var firstKeyEvt *event.VerificationKeyEventContent
if tc.sendingStartsSAS {
// Process the verification accept event on the sending device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Sending device sends first key event to the receiving
// device.
@@ -138,7 +139,7 @@ func TestVerification_SAS(t *testing.T) {
} else {
// Process the verification accept event on the receiving
// device.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Receiving device sends first key event to the sending
// device.
@@ -154,7 +155,7 @@ func TestVerification_SAS(t *testing.T) {
var secondKeyEvt *event.VerificationKeyEventContent
if tc.sendingStartsSAS {
// Process the first key event on the receiving device.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Receiving device sends second key event to the sending
// device.
@@ -169,7 +170,7 @@ func TestVerification_SAS(t *testing.T) {
assert.Len(t, descriptions, 7)
} else {
// Process the first key event on the sending device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Sending device sends second key event to the receiving
// device.
@@ -190,10 +191,10 @@ func TestVerification_SAS(t *testing.T) {
// Ensure that the SAS codes are the same.
if tc.sendingStartsSAS {
// Process the second key event on the sending device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
} else {
// Process the second key event on the receiving device.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
}
assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID))
sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID)
@@ -273,10 +274,10 @@ func TestVerification_SAS(t *testing.T) {
// Test the transaction is done on both sides. We have to dispatch
// twice to process and drain all of the events.
- ts.DispatchToDevice(t, ctx, sendingClient)
- ts.DispatchToDevice(t, ctx, receivingClient)
- ts.DispatchToDevice(t, ctx, sendingClient)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
assert.True(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
})
@@ -287,6 +288,7 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
var err error
@@ -303,10 +305,10 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// event on the sending device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
err = sendingHelper.StartSAS(ctx, txnID)
require.NoError(t, err)
@@ -323,7 +325,7 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID)
// Process the start event from the receiving client to the sending client.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 2)
assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID)
@@ -331,13 +333,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// Process the rest of the events until we need to confirm the SAS.
for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 {
- ts.DispatchToDevice(t, ctx, receivingClient)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
}
// Confirm the SAS only the receiving device.
receivingHelper.ConfirmSAS(ctx, txnID)
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Verification is not done until both devices confirm the SAS.
assert.False(t, sendingCallbacks.IsVerificationDone(txnID))
@@ -348,13 +350,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) {
// Dispatching the events to the receiving device should get us to the done
// state on the receiving device.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
assert.False(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
// Dispatching the events to the sending client should get us to the done
// state on the sending device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
assert.True(t, sendingCallbacks.IsVerificationDone(txnID))
assert.True(t, receivingCallbacks.IsVerificationDone(txnID))
}
diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go
index ce5ec5b4..b4c21c18 100644
--- a/crypto/verificationhelper/verificationhelper_test.go
+++ b/crypto/verificationhelper/verificationhelper_test.go
@@ -19,7 +19,6 @@ import (
"maunium.net/go/mautrix/crypto/verificationhelper"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
- "maunium.net/go/mautrix/mockserver"
)
var aliceUserID = id.UserID("@alice:example.org")
@@ -32,19 +31,9 @@ func init() {
zerolog.DefaultContextLogger = &log.Logger
}
-func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) {
- err := cryptoStore.PutDevice(ctx, userID, &id.Device{
- UserID: userID,
- DeviceID: deviceID,
- })
- if err != nil {
- panic(err)
- }
-}
-
-func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
+func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
t.Helper()
- ts = mockserver.Create(t)
+ ts = createMockServer(t)
sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID)
sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine()
@@ -58,9 +47,9 @@ func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserv
return
}
-func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
+func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) {
t.Helper()
- ts = mockserver.Create(t)
+ ts = createMockServer(t)
sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID)
sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine()
@@ -127,7 +116,8 @@ func TestVerification_Start(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
- ts := mockserver.Create(t)
+ ts := createMockServer(t)
+ defer ts.Close()
client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID)
addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID)
@@ -176,6 +166,7 @@ func TestVerification_StartThenCancel(t *testing.T) {
for _, sendingCancels := range []bool{true, false} {
t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) {
ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID)
@@ -195,13 +186,13 @@ func TestVerification_StartThenCancel(t *testing.T) {
receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID]
assert.Len(t, receivingInbox, 1)
assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Process the request event on the bystander device.
bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID]
assert.Len(t, bystanderInbox, 1)
assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID)
- ts.DispatchToDevice(t, ctx, bystanderClient)
+ ts.dispatchToDevice(t, ctx, bystanderClient)
// Cancel the verification request.
var cancelEvt *event.VerificationCancelEventContent
@@ -240,7 +231,7 @@ func TestVerification_StartThenCancel(t *testing.T) {
if !sendingCancels {
// Process the cancellation event on the sending device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Ensure that the cancellation event was sent to the bystander device.
assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1)
@@ -256,7 +247,8 @@ func TestVerification_StartThenCancel(t *testing.T) {
func TestVerification_Accept_NoSupportedMethods(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
- ts := mockserver.Create(t)
+ ts := createMockServer(t)
+ defer ts.Close()
sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID)
receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID)
@@ -282,7 +274,7 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, txnID)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiver ignored the request because it
// doesn't support any of the verification methods in the
@@ -322,6 +314,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
assert.NoError(t, err)
@@ -340,7 +333,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
require.NoError(t, err)
// Process the verification request on the receiving device.
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Ensure that the receiving device received a verification
// request with the correct transaction ID.
@@ -380,7 +373,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
// Receive the m.key.verification.ready event on the sending
// device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// Ensure that the sending device got a notification about the
// transaction being ready.
@@ -409,6 +402,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
nonParticipatingDeviceID1 := id.DeviceID("non-participating1")
@@ -425,12 +419,12 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
// the receiving device.
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
// Receive the m.key.verification.ready event on the sending device.
- ts.DispatchToDevice(t, ctx, sendingClient)
+ ts.dispatchToDevice(t, ctx, sendingClient)
// The sending and receiving devices should not have any cancellation
// events in their inboxes.
@@ -450,6 +444,7 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) {
func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
_, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
_, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
@@ -457,7 +452,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
txnID, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)
err = receivingHelper.AcceptVerification(ctx, txnID)
@@ -477,6 +472,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) {
func TestVerification_CancelOnDoubleStart(t *testing.T) {
ctx := log.Logger.WithContext(context.TODO())
ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx)
+ defer ts.Close()
sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine)
_, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
@@ -485,15 +481,15 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) {
// Send and accept the first verification request.
txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
err = receivingHelper.AcceptVerification(ctx, txnID1)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event
+ ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event
// Send a second verification request
txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID)
require.NoError(t, err)
- ts.DispatchToDevice(t, ctx, receivingClient)
+ ts.dispatchToDevice(t, ctx, receivingClient)
// Ensure that the sending device received a cancellation event for both of
// the ongoing transactions.
@@ -511,7 +507,7 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) {
assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1))
assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2))
- ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events
+ ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events
assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1))
assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2))
}
diff --git a/error.go b/error.go
index 4711b3dc..6f4880df 100644
--- a/error.go
+++ b/error.go
@@ -13,7 +13,6 @@ import (
"net/http"
"go.mau.fi/util/exhttp"
- "go.mau.fi/util/exmaps"
"golang.org/x/exp/maps"
)
@@ -67,8 +66,6 @@ var (
MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"}
// The client specified a parameter that has the wrong value.
MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest}
- // The client specified a room key backup version that is not the current room key backup version for the user.
- MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden}
MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"}
MBadStatus = RespError{ErrCode: "M_BAD_STATUS"}
@@ -82,13 +79,6 @@ var (
var (
ErrClientIsNil = errors.New("client is nil")
ErrClientHasNoHomeserver = errors.New("client has no homeserver set")
-
- ErrResponseTooLong = errors.New("response content length too long")
- ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body")
-
- // Special error that indicates we should retry canceled contexts. Note that on it's own this
- // is useless, the context itself must also be replaced.
- ErrContextCancelRetry = errors.New("retry canceled context")
)
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.
@@ -140,10 +130,7 @@ type RespError struct {
Err string
ExtraData map[string]any
- StatusCode int
- ExtraHeader map[string]string
-
- CanRetry bool
+ StatusCode int
}
func (e *RespError) UnmarshalJSON(data []byte) error {
@@ -153,17 +140,16 @@ func (e *RespError) UnmarshalJSON(data []byte) error {
}
e.ErrCode, _ = e.ExtraData["errcode"].(string)
e.Err, _ = e.ExtraData["error"].(string)
- e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool)
return nil
}
func (e *RespError) MarshalJSON() ([]byte, error) {
- data := exmaps.NonNilClone(e.ExtraData)
+ data := maps.Clone(e.ExtraData)
+ if data == nil {
+ data = make(map[string]any)
+ }
data["errcode"] = e.ErrCode
data["error"] = e.Err
- if e.CanRetry {
- data["com.beeper.can_retry"] = e.CanRetry
- }
return json.Marshal(data)
}
@@ -175,9 +161,6 @@ func (e RespError) Write(w http.ResponseWriter) {
if statusCode == 0 {
statusCode = http.StatusInternalServerError
}
- for key, value := range e.ExtraHeader {
- w.Header().Set(key, value)
- }
exhttp.WriteJSONResponse(w, statusCode, &e)
}
@@ -194,29 +177,6 @@ func (e RespError) WithStatus(status int) RespError {
return e
}
-func (e RespError) WithCanRetry(canRetry bool) RespError {
- e.CanRetry = canRetry
- return e
-}
-
-func (e RespError) WithExtraData(extraData map[string]any) RespError {
- e.ExtraData = exmaps.NonNilClone(e.ExtraData)
- maps.Copy(e.ExtraData, extraData)
- return e
-}
-
-func (e RespError) WithExtraHeader(key, value string) RespError {
- e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader)
- e.ExtraHeader[key] = value
- return e
-}
-
-func (e RespError) WithExtraHeaders(headers map[string]string) RespError {
- e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader)
- maps.Copy(e.ExtraHeader, headers)
- return e
-}
-
// Error returns the errcode and error message.
func (e RespError) Error() string {
return e.ErrCode + ": " + e.Err
diff --git a/event/accountdata.go b/event/accountdata.go
index 223919a1..30ca35a2 100644
--- a/event/accountdata.go
+++ b/event/accountdata.go
@@ -105,15 +105,3 @@ func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time {
}
return time.Time{}
}
-
-func (bmec *BeeperMuteEventContent) GetMuteDuration() time.Duration {
- ts := bmec.GetMutedUntilTime()
- now := time.Now()
- if ts.Before(now) {
- return 0
- } else if ts == MutedForever {
- return -1
- } else {
- return ts.Sub(now)
- }
-}
diff --git a/event/beeper.go b/event/beeper.go
index a1a60b35..921e3466 100644
--- a/event/beeper.go
+++ b/event/beeper.go
@@ -53,8 +53,6 @@ type BeeperMessageStatusEventContent struct {
LastRetry id.EventID `json:"last_retry,omitempty"`
- TargetTxnID string `json:"relates_to_txn_id,omitempty"`
-
MutateEventKey string `json:"mutate_event_key,omitempty"`
// Indicates the set of users to whom the event was delivered. If nil, then
@@ -88,22 +86,6 @@ type BeeperRoomKeyAckEventContent struct {
FirstMessageIndex int `json:"first_message_index"`
}
-type BeeperChatDeleteEventContent struct {
- DeleteForEveryone bool `json:"delete_for_everyone,omitempty"`
- FromMessageRequest bool `json:"from_message_request,omitempty"`
-}
-
-type BeeperAcceptMessageRequestEventContent struct {
- // Whether this was triggered by a message rather than an explicit event
- IsImplicit bool `json:"-"`
-}
-
-type BeeperSendStateEventContent struct {
- Type string `json:"type"`
- StateKey string `json:"state_key"`
- Content Content `json:"content"`
-}
-
type IntOrString int
func (ios *IntOrString) UnmarshalJSON(data []byte) error {
@@ -146,7 +128,6 @@ type BeeperLinkPreview struct {
MatchedURL string `json:"matched_url,omitempty"`
ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"`
- ImageBlurhash string `json:"matrix:image:blurhash,omitempty"`
}
type BeeperProfileExtra struct {
@@ -166,24 +147,6 @@ type BeeperPerMessageProfile struct {
HasFallback bool `json:"has_fallback,omitempty"`
}
-type BeeperActionMessageType string
-
-const (
- BeeperActionMessageCall BeeperActionMessageType = "call"
-)
-
-type BeeperActionMessageCallType string
-
-const (
- BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice"
- BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video"
-)
-
-type BeeperActionMessage struct {
- Type BeeperActionMessageType `json:"type"`
- CallType BeeperActionMessageCallType `json:"call_type,omitempty"`
-}
-
func (content *MessageEventContent) AddPerMessageProfileFallback() {
if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" {
return
@@ -214,15 +177,6 @@ func (content *MessageEventContent) RemovePerMessageProfileFallback() {
}
}
-type BeeperAIStreamEventContent struct {
- TurnID string `json:"turn_id"`
- Seq int `json:"seq"`
- Part map[string]any `json:"part"`
- TargetEvent id.EventID `json:"target_event,omitempty"`
- AgentID string `json:"agent_id,omitempty"`
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
-}
-
type BeeperEncodedOrder struct {
order int64
suborder int16
diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts
index 26aeb347..4cf29de7 100644
--- a/event/capabilities.d.ts
+++ b/event/capabilities.d.ts
@@ -16,23 +16,6 @@ export interface RoomFeatures {
* If a message type isn't listed here, it should be treated as support level -2 (will be rejected).
*/
file?: Record
- /**
- * Supported state event types and their parameters. Currently, there are no parameters,
- * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.).
- *
- * Events that are not listed or have a support level of zero or below should be treated as unsupported.
- *
- * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here.
- * `m.room.member` will not be listed here, as it's controlled by the member_actions field.
- * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now.
- */
- state?: Record
- /**
- * Supported member actions and their support levels.
- *
- * Actions that are not listed or have a support level of zero or below should be treated as unsupported.
- */
- member_actions?: Record
/** Maximum length of normal text messages. */
max_text_length?: integer
@@ -58,8 +41,6 @@ export interface RoomFeatures {
delete_max_age?: seconds
/** Whether deleting messages just for yourself is supported. No message age limit. */
delete_for_me?: boolean
- /** Allowed configuration options for disappearing timers. */
- disappearing_timer?: DisappearingTimerCapability
/** Whether reactions are supported. */
reaction?: CapabilitySupportLevel
@@ -72,21 +53,10 @@ export interface RoomFeatures {
allowed_reactions?: string[]
/** Whether custom emoji reactions are allowed. */
custom_emoji_reactions?: boolean
-
- /** Whether deleting the chat for yourself is supported. */
- delete_chat?: boolean
- /** Whether deleting the chat for all participants is supported. */
- delete_chat_for_everyone?: boolean
- /** What can be done with message requests? */
- message_request?: {
- accept_with_message?: CapabilitySupportLevel
- accept_with_button?: CapabilitySupportLevel
- }
}
declare type integer = number
declare type seconds = integer
-declare type milliseconds = integer
declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application"
declare type MIMETypeOrPattern =
"*/*"
@@ -94,21 +64,6 @@ declare type MIMETypeOrPattern =
| `${MIMEClass}/${string}`
| `${MIMEClass}/${string}; ${string}`
-export enum MemberAction {
- Ban = "ban",
- Kick = "kick",
- Leave = "leave",
- RevokeInvite = "revoke_invite",
- Invite = "invite",
-}
-
-declare type EventType = string
-
-// This is an object for future extensibility (e.g. max name/topic length)
-export interface StateFeatures {
- level: CapabilitySupportLevel
-}
-
export enum CapabilityMsgType {
// Real message types used in the `msgtype` field
Image = "m.image",
@@ -151,25 +106,6 @@ export interface FileFeatures {
view_once?: boolean
}
-export enum DisappearingType {
- None = "",
- AfterRead = "after_read",
- AfterSend = "after_send",
-}
-
-export interface DisappearingTimerCapability {
- types: DisappearingType[]
- /** Allowed timer values. If omitted, any timer is allowed. */
- timers?: milliseconds[]
- /**
- * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear
- *
- * Generally, bridged rooms will want the object to be always present, while native Matrix rooms don't,
- * so the hardcoded features for Matrix rooms should set this to true, while bridges will not.
- */
- omit_empty_timer?: true
-}
-
/**
* The support level for a feature. These are integers rather than booleans
* to accurately represent what the bridge is doing and hopefully make the
diff --git a/event/capabilities.go b/event/capabilities.go
index a86c726b..9c9eb09a 100644
--- a/event/capabilities.go
+++ b/event/capabilities.go
@@ -18,7 +18,6 @@ import (
"go.mau.fi/util/exerrors"
"go.mau.fi/util/jsontime"
- "go.mau.fi/util/ptr"
"golang.org/x/exp/constraints"
"golang.org/x/exp/maps"
)
@@ -28,10 +27,8 @@ type RoomFeatures struct {
// N.B. New fields need to be added to the Hash function to be included in the deduplication hash.
- Formatting FormattingFeatureMap `json:"formatting,omitempty"`
- File FileFeatureMap `json:"file,omitempty"`
- State StateFeatureMap `json:"state,omitempty"`
- MemberActions MemberFeatureMap `json:"member_actions,omitempty"`
+ Formatting FormattingFeatureMap `json:"formatting,omitempty"`
+ File FileFeatureMap `json:"file,omitempty"`
MaxTextLength int `json:"max_text_length,omitempty"`
@@ -47,23 +44,16 @@ type RoomFeatures struct {
DeleteForMe bool `json:"delete_for_me,omitempty"`
DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"`
- DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"`
-
Reaction CapabilitySupportLevel `json:"reaction,omitempty"`
ReactionCount int `json:"reaction_count,omitempty"`
AllowedReactions []string `json:"allowed_reactions,omitempty"`
CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"`
- ReadReceipts bool `json:"read_receipts,omitempty"`
- TypingNotifications bool `json:"typing_notifications,omitempty"`
- Archive bool `json:"archive,omitempty"`
- MarkAsUnread bool `json:"mark_as_unread,omitempty"`
- DeleteChat bool `json:"delete_chat,omitempty"`
- DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"`
-
- MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"`
-
- PerMessageProfileRelay bool `json:"-"`
+ ReadReceipts bool `json:"read_receipts,omitempty"`
+ TypingNotifications bool `json:"typing_notifications,omitempty"`
+ Archive bool `json:"archive,omitempty"`
+ MarkAsUnread bool `json:"mark_as_unread,omitempty"`
+ DeleteChat bool `json:"delete_chat,omitempty"`
}
func (rf *RoomFeatures) GetID() string {
@@ -73,120 +63,10 @@ func (rf *RoomFeatures) GetID() string {
return base64.RawURLEncoding.EncodeToString(rf.Hash())
}
-func (rf *RoomFeatures) Clone() *RoomFeatures {
- if rf == nil {
- return nil
- }
- clone := *rf
- clone.File = clone.File.Clone()
- clone.Formatting = maps.Clone(clone.Formatting)
- clone.State = clone.State.Clone()
- clone.MemberActions = clone.MemberActions.Clone()
- clone.EditMaxAge = ptr.Clone(clone.EditMaxAge)
- clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge)
- clone.DisappearingTimer = clone.DisappearingTimer.Clone()
- clone.AllowedReactions = slices.Clone(clone.AllowedReactions)
- clone.MessageRequest = clone.MessageRequest.Clone()
- return &clone
-}
-
-type MemberFeatureMap map[MemberAction]CapabilitySupportLevel
-
-func (mfm MemberFeatureMap) Clone() MemberFeatureMap {
- return maps.Clone(mfm)
-}
-
-type MemberAction string
-
-const (
- MemberActionBan MemberAction = "ban"
- MemberActionKick MemberAction = "kick"
- MemberActionLeave MemberAction = "leave"
- MemberActionRevokeInvite MemberAction = "revoke_invite"
- MemberActionInvite MemberAction = "invite"
-)
-
-type StateFeatureMap map[string]*StateFeatures
-
-func (sfm StateFeatureMap) Clone() StateFeatureMap {
- dup := maps.Clone(sfm)
- for key, value := range dup {
- dup[key] = value.Clone()
- }
- return dup
-}
-
-type StateFeatures struct {
- Level CapabilitySupportLevel `json:"level"`
-}
-
-func (sf *StateFeatures) Clone() *StateFeatures {
- if sf == nil {
- return nil
- }
- clone := *sf
- return &clone
-}
-
-func (sf *StateFeatures) Hash() []byte {
- return sf.Level.Hash()
-}
-
type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel
type FileFeatureMap map[CapabilityMsgType]*FileFeatures
-func (ffm FileFeatureMap) Clone() FileFeatureMap {
- dup := maps.Clone(ffm)
- for key, value := range dup {
- dup[key] = value.Clone()
- }
- return dup
-}
-
-type DisappearingTimerCapability struct {
- Types []DisappearingType `json:"types"`
- Timers []jsontime.Milliseconds `json:"timers,omitempty"`
-
- OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"`
-}
-
-func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability {
- if dtc == nil {
- return nil
- }
- clone := *dtc
- clone.Types = slices.Clone(clone.Types)
- clone.Timers = slices.Clone(clone.Timers)
- return &clone
-}
-
-func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool {
- if dtc == nil || content == nil || content.Type == DisappearingTypeNone {
- return true
- }
- return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer))
-}
-
-type MessageRequestFeatures struct {
- AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"`
- AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"`
-}
-
-func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures {
- return ptr.Clone(mrf)
-}
-
-func (mrf *MessageRequestFeatures) Hash() []byte {
- if mrf == nil {
- return nil
- }
- hasher := sha256.New()
- hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage)
- hashValue(hasher, "accept_with_button", mrf.AcceptWithButton)
- return hasher.Sum(nil)
-}
-
type CapabilityMsgType = MessageType
// Message types which are used for event capability signaling, but aren't real values for the msgtype field.
@@ -336,8 +216,6 @@ func (rf *RoomFeatures) Hash() []byte {
hashMap(hasher, "formatting", rf.Formatting)
hashMap(hasher, "file", rf.File)
- hashMap(hasher, "state", rf.State)
- hashMap(hasher, "member_actions", rf.MemberActions)
hashInt(hasher, "max_text_length", rf.MaxTextLength)
@@ -353,7 +231,6 @@ func (rf *RoomFeatures) Hash() []byte {
hashValue(hasher, "delete", rf.Delete)
hashBool(hasher, "delete_for_me", rf.DeleteForMe)
hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get())
- hashValue(hasher, "disappearing_timer", rf.DisappearingTimer)
hashValue(hasher, "reaction", rf.Reaction)
hashInt(hasher, "reaction_count", rf.ReactionCount)
@@ -368,28 +245,10 @@ func (rf *RoomFeatures) Hash() []byte {
hashBool(hasher, "archive", rf.Archive)
hashBool(hasher, "mark_as_unread", rf.MarkAsUnread)
hashBool(hasher, "delete_chat", rf.DeleteChat)
- hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone)
- hashValue(hasher, "message_request", rf.MessageRequest)
return hasher.Sum(nil)
}
-func (dtc *DisappearingTimerCapability) Hash() []byte {
- if dtc == nil {
- return nil
- }
- hasher := sha256.New()
- hasher.Write([]byte("types"))
- for _, t := range dtc.Types {
- hasher.Write([]byte(t))
- }
- hasher.Write([]byte("timers"))
- for _, timer := range dtc.Timers {
- hashInt(hasher, "", timer.Milliseconds())
- }
- return hasher.Sum(nil)
-}
-
func (ff *FileFeatures) Hash() []byte {
hasher := sha256.New()
hashMap(hasher, "mime_types", ff.MimeTypes)
@@ -402,13 +261,3 @@ func (ff *FileFeatures) Hash() []byte {
hashBool(hasher, "view_once", ff.ViewOnce)
return hasher.Sum(nil)
}
-
-func (ff *FileFeatures) Clone() *FileFeatures {
- if ff == nil {
- return nil
- }
- clone := *ff
- clone.MimeTypes = maps.Clone(clone.MimeTypes)
- clone.MaxDuration = ptr.Clone(clone.MaxDuration)
- return &clone
-}
diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go
deleted file mode 100644
index ce07c4c0..00000000
--- a/event/cmdschema/content.go
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package cmdschema
-
-import (
- "crypto/sha256"
- "encoding/base64"
- "fmt"
- "reflect"
- "slices"
-
- "go.mau.fi/util/exsync"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-type EventContent struct {
- Command string `json:"command"`
- Aliases []string `json:"aliases,omitempty"`
- Parameters []*Parameter `json:"parameters,omitempty"`
- Description *event.ExtensibleTextContainer `json:"description,omitempty"`
- TailParam string `json:"fi.mau.tail_parameter,omitempty"`
-}
-
-func (ec *EventContent) Validate() error {
- if ec == nil {
- return fmt.Errorf("event content is nil")
- } else if ec.Command == "" {
- return fmt.Errorf("command is empty")
- }
- var tailFound bool
- dupMap := exsync.NewSet[string]()
- for i, p := range ec.Parameters {
- if err := p.Validate(); err != nil {
- return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err)
- } else if !dupMap.Add(p.Key) {
- return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1)
- } else if p.Key == ec.TailParam {
- tailFound = true
- } else if tailFound && !p.Optional {
- return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam)
- }
- }
- if ec.TailParam != "" && !tailFound {
- return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam)
- }
- return nil
-}
-
-func (ec *EventContent) IsValid() bool {
- return ec.Validate() == nil
-}
-
-func (ec *EventContent) StateKey(owner id.UserID) string {
- hash := sha256.Sum256([]byte(ec.Command + owner.String()))
- return base64.StdEncoding.EncodeToString(hash[:])
-}
-
-func (ec *EventContent) Equals(other *EventContent) bool {
- if ec == nil || other == nil {
- return ec == other
- }
- return ec.Command == other.Command &&
- slices.Equal(ec.Aliases, other.Aliases) &&
- slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) &&
- ec.Description.Equals(other.Description) &&
- ec.TailParam == other.TailParam
-}
-
-func init() {
- event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{})
-}
diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go
deleted file mode 100644
index 4193b297..00000000
--- a/event/cmdschema/parameter.go
+++ /dev/null
@@ -1,286 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package cmdschema
-
-import (
- "encoding/json"
- "fmt"
- "slices"
-
- "go.mau.fi/util/exslices"
-
- "maunium.net/go/mautrix/event"
-)
-
-type Parameter struct {
- Key string `json:"key"`
- Schema *ParameterSchema `json:"schema"`
- Optional bool `json:"optional,omitempty"`
- Description *event.ExtensibleTextContainer `json:"description,omitempty"`
- DefaultValue any `json:"fi.mau.default_value,omitempty"`
-}
-
-func (p *Parameter) Equals(other *Parameter) bool {
- if p == nil || other == nil {
- return p == other
- }
- return p.Key == other.Key &&
- p.Schema.Equals(other.Schema) &&
- p.Optional == other.Optional &&
- p.Description.Equals(other.Description) &&
- p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values
-}
-
-func (p *Parameter) Validate() error {
- if p == nil {
- return fmt.Errorf("parameter is nil")
- } else if p.Key == "" {
- return fmt.Errorf("key is empty")
- }
- return p.Schema.Validate()
-}
-
-func (p *Parameter) IsValid() bool {
- return p.Validate() == nil
-}
-
-func (p *Parameter) GetDefaultValue() any {
- if p != nil && p.DefaultValue != nil {
- return p.DefaultValue
- } else if p == nil || p.Optional {
- return nil
- }
- return p.Schema.GetDefaultValue()
-}
-
-type PrimitiveType string
-
-const (
- PrimitiveTypeString PrimitiveType = "string"
- PrimitiveTypeInteger PrimitiveType = "integer"
- PrimitiveTypeBoolean PrimitiveType = "boolean"
- PrimitiveTypeServerName PrimitiveType = "server_name"
- PrimitiveTypeUserID PrimitiveType = "user_id"
- PrimitiveTypeRoomID PrimitiveType = "room_id"
- PrimitiveTypeRoomAlias PrimitiveType = "room_alias"
- PrimitiveTypeEventID PrimitiveType = "event_id"
-)
-
-func (pt PrimitiveType) Schema() *ParameterSchema {
- return &ParameterSchema{
- SchemaType: SchemaTypePrimitive,
- Type: pt,
- }
-}
-
-func (pt PrimitiveType) IsValid() bool {
- switch pt {
- case PrimitiveTypeString,
- PrimitiveTypeInteger,
- PrimitiveTypeBoolean,
- PrimitiveTypeServerName,
- PrimitiveTypeUserID,
- PrimitiveTypeRoomID,
- PrimitiveTypeRoomAlias,
- PrimitiveTypeEventID:
- return true
- default:
- return false
- }
-}
-
-type SchemaType string
-
-const (
- SchemaTypePrimitive SchemaType = "primitive"
- SchemaTypeArray SchemaType = "array"
- SchemaTypeUnion SchemaType = "union"
- SchemaTypeLiteral SchemaType = "literal"
-)
-
-type ParameterSchema struct {
- SchemaType SchemaType `json:"schema_type"`
- Type PrimitiveType `json:"type,omitempty"` // Only for primitive
- Items *ParameterSchema `json:"items,omitempty"` // Only for array
- Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union
- Value any `json:"value,omitempty"` // Only for literal
-}
-
-func Literal(value any) *ParameterSchema {
- return &ParameterSchema{
- SchemaType: SchemaTypeLiteral,
- Value: value,
- }
-}
-
-func Enum(values ...any) *ParameterSchema {
- return Union(exslices.CastFunc(values, Literal)...)
-}
-
-func flattenUnion(variants []*ParameterSchema) []*ParameterSchema {
- var flattened []*ParameterSchema
- for _, variant := range variants {
- switch variant.SchemaType {
- case SchemaTypeArray:
- panic(fmt.Errorf("illegal array schema in union"))
- case SchemaTypeUnion:
- flattened = append(flattened, flattenUnion(variant.Variants)...)
- default:
- flattened = append(flattened, variant)
- }
- }
- return flattened
-}
-
-func Union(variants ...*ParameterSchema) *ParameterSchema {
- needsFlattening := false
- for _, variant := range variants {
- if variant.SchemaType == SchemaTypeArray {
- panic(fmt.Errorf("illegal array schema in union"))
- } else if variant.SchemaType == SchemaTypeUnion {
- needsFlattening = true
- }
- }
- if needsFlattening {
- variants = flattenUnion(variants)
- }
- return &ParameterSchema{
- SchemaType: SchemaTypeUnion,
- Variants: variants,
- }
-}
-
-func Array(items *ParameterSchema) *ParameterSchema {
- if items.SchemaType == SchemaTypeArray {
- panic(fmt.Errorf("illegal array schema in array"))
- }
- return &ParameterSchema{
- SchemaType: SchemaTypeArray,
- Items: items,
- }
-}
-
-func (ps *ParameterSchema) GetDefaultValue() any {
- if ps == nil {
- return nil
- }
- switch ps.SchemaType {
- case SchemaTypePrimitive:
- switch ps.Type {
- case PrimitiveTypeInteger:
- return 0
- case PrimitiveTypeBoolean:
- return false
- default:
- return ""
- }
- case SchemaTypeArray:
- return []any{}
- case SchemaTypeUnion:
- if len(ps.Variants) > 0 {
- return ps.Variants[0].GetDefaultValue()
- }
- return nil
- case SchemaTypeLiteral:
- return ps.Value
- default:
- return nil
- }
-}
-
-func (ps *ParameterSchema) IsValid() bool {
- return ps.validate("") == nil
-}
-
-func (ps *ParameterSchema) Validate() error {
- return ps.validate("")
-}
-
-func (ps *ParameterSchema) validate(parent SchemaType) error {
- if ps == nil {
- return fmt.Errorf("schema is nil")
- }
- switch ps.SchemaType {
- case SchemaTypePrimitive:
- if !ps.Type.IsValid() {
- return fmt.Errorf("invalid primitive type %s", ps.Type)
- } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil {
- return fmt.Errorf("primitive schema has extra fields")
- }
- return nil
- case SchemaTypeArray:
- if parent != "" {
- return fmt.Errorf("arrays can't be nested in other types")
- } else if err := ps.Items.validate(ps.SchemaType); err != nil {
- return fmt.Errorf("item schema is invalid: %w", err)
- } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil {
- return fmt.Errorf("array schema has extra fields")
- }
- return nil
- case SchemaTypeUnion:
- if len(ps.Variants) == 0 {
- return fmt.Errorf("no variants specified for union")
- } else if parent != "" && parent != SchemaTypeArray {
- return fmt.Errorf("unions can't be nested in anything other than arrays")
- }
- for i, v := range ps.Variants {
- if err := v.validate(ps.SchemaType); err != nil {
- return fmt.Errorf("variant #%d is invalid: %w", i+1, err)
- }
- }
- if ps.Type != "" || ps.Items != nil || ps.Value != nil {
- return fmt.Errorf("union schema has extra fields")
- }
- return nil
- case SchemaTypeLiteral:
- switch typedVal := ps.Value.(type) {
- case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue:
- // ok
- case map[string]any:
- if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" {
- return fmt.Errorf("literal value has invalid map data")
- }
- default:
- return fmt.Errorf("literal value has unsupported type %T", ps.Value)
- }
- if ps.Type != "" || ps.Items != nil || ps.Variants != nil {
- return fmt.Errorf("literal schema has extra fields")
- }
- return nil
- default:
- return fmt.Errorf("invalid schema type %s", ps.SchemaType)
- }
-}
-
-func (ps *ParameterSchema) Equals(other *ParameterSchema) bool {
- if ps == nil || other == nil {
- return ps == other
- }
- return ps.SchemaType == other.SchemaType &&
- ps.Type == other.Type &&
- ps.Items.Equals(other.Items) &&
- slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) &&
- ps.Value == other.Value // TODO this won't work for room/event ID values
-}
-
-func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool {
- switch ps.SchemaType {
- case SchemaTypePrimitive:
- return ps.Type == prim
- case SchemaTypeUnion:
- for _, variant := range ps.Variants {
- if variant.AllowsPrimitive(prim) {
- return true
- }
- }
- return false
- case SchemaTypeArray:
- return ps.Items.AllowsPrimitive(prim)
- default:
- return false
- }
-}
diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go
deleted file mode 100644
index 92e69b60..00000000
--- a/event/cmdschema/parse.go
+++ /dev/null
@@ -1,478 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package cmdschema
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "regexp"
- "strconv"
- "strings"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-const botArrayOpener = "<"
-const botArrayCloser = ">"
-
-func parseQuoted(val string) (parsed, remaining string, quoted bool) {
- if len(val) == 0 {
- return
- }
- if !strings.HasPrefix(val, `"`) {
- spaceIdx := strings.IndexByte(val, ' ')
- if spaceIdx == -1 {
- parsed = val
- } else {
- parsed = val[:spaceIdx]
- remaining = strings.TrimLeft(val[spaceIdx+1:], " ")
- }
- return
- }
- val = val[1:]
- var buf strings.Builder
- for {
- quoteIdx := strings.IndexByte(val, '"')
- var valUntilQuote string
- if quoteIdx == -1 {
- valUntilQuote = val
- } else {
- valUntilQuote = val[:quoteIdx]
- }
- escapeIdx := strings.IndexByte(valUntilQuote, '\\')
- if escapeIdx >= 0 {
- buf.WriteString(val[:escapeIdx])
- if len(val) > escapeIdx+1 {
- buf.WriteByte(val[escapeIdx+1])
- }
- val = val[min(escapeIdx+2, len(val)):]
- } else if quoteIdx >= 0 {
- buf.WriteString(val[:quoteIdx])
- val = val[quoteIdx+1:]
- break
- } else if buf.Len() == 0 {
- // Unterminated quote, no escape characters, val is the whole input
- return val, "", true
- } else {
- // Unterminated quote, but there were escape characters previously
- buf.WriteString(val)
- val = ""
- break
- }
- }
- return buf.String(), strings.TrimLeft(val, " "), true
-}
-
-// ParseInput tries to parse the given text into a bot command event matching this command definition.
-//
-// If the prefix doesn't match, this will return a nil content and nil error.
-// If the prefix does match, some content is always returned, but there may still be an error if parsing failed.
-func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) {
- prefix := ec.parsePrefix(input, sigils, owner.String())
- if prefix == "" {
- return nil, nil
- }
- content = &event.MessageEventContent{
- MsgType: event.MsgText,
- Body: input,
- Mentions: &event.Mentions{UserIDs: []id.UserID{owner}},
- MSC4391BotCommand: &event.MSC4391BotCommandInput{
- Command: ec.Command,
- },
- }
- content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):])
- return content, err
-}
-
-func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) {
- args := make(map[string]any)
- var retErr error
- setError := func(err error) {
- if err != nil && retErr == nil {
- retErr = err
- }
- }
- processParameter := func(param *Parameter, isLast, isTail, isNamed bool) {
- origInput := input
- var nextVal string
- var wasQuoted bool
- if param.Schema.SchemaType == SchemaTypeArray {
- hasOpener := strings.HasPrefix(input, botArrayOpener)
- arrayClosed := false
- if hasOpener {
- input = input[len(botArrayOpener):]
- if strings.HasPrefix(input, botArrayCloser) {
- input = strings.TrimLeft(input[len(botArrayCloser):], " ")
- arrayClosed = true
- }
- }
- var collector []any
- for len(input) > 0 && !arrayClosed {
- //origInput = input
- nextVal, input, wasQuoted = parseQuoted(input)
- if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) {
- // The value wasn't quoted and has the array delimiter at the end, close the array
- nextVal = strings.TrimRight(nextVal, botArrayCloser)
- arrayClosed = true
- } else if hasOpener && strings.HasPrefix(input, botArrayCloser) {
- // The value was quoted or there was a space, and the next character is the
- // array delimiter, close the array
- input = strings.TrimLeft(input[len(botArrayCloser):], " ")
- arrayClosed = true
- } else if !hasOpener && !isLast {
- // For array arguments in the middle without the <> delimiters, stop after the first item
- arrayClosed = true
- }
- parsedVal, err := param.Schema.Items.ParseString(nextVal)
- if err == nil {
- collector = append(collector, parsedVal)
- } else if hasOpener || isLast {
- setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err))
- } else {
- //input = origInput
- }
- }
- args[param.Key] = collector
- } else {
- nextVal, input, wasQuoted = parseQuoted(input)
- if (isLast || isTail) && !wasQuoted && len(input) > 0 {
- // If the last argument is not quoted, just treat the rest of the string
- // as the argument without escapes (arguments with escapes should be quoted).
- nextVal += " " + input
- input = ""
- }
- // Special case for named boolean parameters: if no value is given, treat it as true
- if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) {
- args[param.Key] = true
- return
- }
- if nextVal == "" && !wasQuoted && !isNamed && !param.Optional {
- setError(fmt.Errorf("missing value for required parameter %s", param.Key))
- }
- parsedVal, err := param.Schema.ParseString(nextVal)
- if err != nil {
- args[param.Key] = param.GetDefaultValue()
- // For optional parameters that fail to parse, restore the input and try passing it as the next parameter
- if param.Optional && !isLast && !isNamed {
- input = strings.TrimLeft(origInput, " ")
- } else if !param.Optional || isNamed {
- setError(fmt.Errorf("failed to parse %s: %w", param.Key, err))
- }
- } else {
- args[param.Key] = parsedVal
- }
- }
- }
- skipParams := make([]bool, len(ec.Parameters))
- for i, param := range ec.Parameters {
- for strings.HasPrefix(input, "--") {
- nameEndIdx := strings.IndexAny(input, " =")
- if nameEndIdx == -1 {
- nameEndIdx = len(input)
- }
- overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx])
- if overrideParam != nil {
- // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input
- input = strings.TrimPrefix(input[nameEndIdx:], "=")
- skipParams[paramIdx] = true
- processParameter(overrideParam, false, false, true)
- } else {
- break
- }
- }
- isTail := param.Key == ec.TailParam
- if skipParams[i] || (param.Optional && !isTail) {
- continue
- }
- processParameter(param, i == len(ec.Parameters)-1, isTail, false)
- }
- jsonArgs, marshalErr := json.Marshal(args)
- if marshalErr != nil {
- return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr)
- }
- return jsonArgs, retErr
-}
-
-func (ec *EventContent) parameterByName(name string) (*Parameter, int) {
- for i, param := range ec.Parameters {
- if strings.EqualFold(param.Key, name) {
- return param, i
- }
- }
- return nil, -1
-}
-
-func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) {
- input := origInput
- var chosenSigil string
- for _, sigil := range sigils {
- if strings.HasPrefix(input, sigil) {
- chosenSigil = sigil
- break
- }
- }
- if chosenSigil == "" {
- return ""
- }
- input = input[len(chosenSigil):]
- var chosenAlias string
- if !strings.HasPrefix(input, ec.Command) {
- for _, alias := range ec.Aliases {
- if strings.HasPrefix(input, alias) {
- chosenAlias = alias
- break
- }
- }
- if chosenAlias == "" {
- return ""
- }
- } else {
- chosenAlias = ec.Command
- }
- input = strings.TrimPrefix(input[len(chosenAlias):], owner)
- if input == "" || input[0] == ' ' {
- input = strings.TrimLeft(input, " ")
- return origInput[:len(origInput)-len(input)]
- }
- return ""
-}
-
-func (pt PrimitiveType) ValidateValue(value any) bool {
- _, err := pt.NormalizeValue(value)
- return err == nil
-}
-
-func normalizeNumber(value any) (int, error) {
- switch typedValue := value.(type) {
- case int:
- return typedValue, nil
- case int64:
- return int(typedValue), nil
- case float64:
- return int(typedValue), nil
- case json.Number:
- if i, err := typedValue.Int64(); err != nil {
- return 0, fmt.Errorf("failed to parse json.Number: %w", err)
- } else {
- return int(i), nil
- }
- default:
- return 0, fmt.Errorf("unsupported type %T for integer", value)
- }
-}
-
-func (pt PrimitiveType) NormalizeValue(value any) (any, error) {
- switch pt {
- case PrimitiveTypeInteger:
- return normalizeNumber(value)
- case PrimitiveTypeBoolean:
- bv, ok := value.(bool)
- if !ok {
- return nil, fmt.Errorf("unsupported type %T for boolean", value)
- }
- return bv, nil
- case PrimitiveTypeString, PrimitiveTypeServerName:
- str, ok := value.(string)
- if !ok {
- return nil, fmt.Errorf("unsupported type %T for string", value)
- }
- return str, pt.validateStringValue(str)
- case PrimitiveTypeUserID, PrimitiveTypeRoomAlias:
- str, ok := value.(string)
- if !ok {
- return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value)
- } else if plainErr := pt.validateStringValue(str); plainErr == nil {
- return str, nil
- } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil {
- return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err)
- } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID {
- return parsed.UserID(), nil
- } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias {
- return parsed.RoomAlias(), nil
- } else {
- return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1)
- }
- case PrimitiveTypeRoomID, PrimitiveTypeEventID:
- riv, err := NormalizeRoomIDValue(value)
- if err != nil {
- return nil, err
- }
- return riv, riv.Validate()
- default:
- return nil, fmt.Errorf("cannot normalize value for argument type %s", pt)
- }
-}
-
-func (pt PrimitiveType) validateStringValue(value string) error {
- switch pt {
- case PrimitiveTypeString:
- return nil
- case PrimitiveTypeServerName:
- if !id.ValidateServerName(value) {
- return fmt.Errorf("invalid server name: %q", value)
- }
- return nil
- case PrimitiveTypeUserID:
- _, _, err := id.UserID(value).ParseAndValidateRelaxed()
- return err
- case PrimitiveTypeRoomAlias:
- sigil, localpart, serverName := id.ParseCommonIdentifier(value)
- if sigil != '#' || localpart == "" || serverName == "" {
- return fmt.Errorf("invalid room alias: %q", value)
- } else if !id.ValidateServerName(serverName) {
- return fmt.Errorf("invalid server name in room alias: %q", serverName)
- }
- return nil
- default:
- panic(fmt.Errorf("validateStringValue called with invalid type %s", pt))
- }
-}
-
-func parseBoolean(val string) (bool, error) {
- if len(val) == 0 {
- return false, fmt.Errorf("cannot parse empty string as boolean")
- }
- switch strings.ToLower(val) {
- case "t", "true", "y", "yes", "1":
- return true, nil
- case "f", "false", "n", "no", "0":
- return false, nil
- default:
- return false, fmt.Errorf("invalid boolean string: %q", val)
- }
-}
-
-var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`)
-
-func parseRoomOrEventID(value string) (*RoomIDValue, error) {
- if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") {
- matches := markdownLinkRegex.FindStringSubmatch(value)
- if len(matches) == 2 {
- value = matches[1]
- }
- }
- parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
- if err != nil && strings.HasPrefix(value, "!") {
- return &RoomIDValue{
- Type: PrimitiveTypeRoomID,
- RoomID: id.RoomID(value),
- }, nil
- }
- if err != nil {
- return nil, err
- } else if parsed.Sigil1 != '!' {
- return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1)
- } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' {
- return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2)
- }
- valType := PrimitiveTypeRoomID
- if parsed.MXID2 != "" {
- valType = PrimitiveTypeEventID
- }
- return &RoomIDValue{
- Type: valType,
- RoomID: parsed.RoomID(),
- Via: parsed.Via,
- EventID: parsed.EventID(),
- }, nil
-}
-
-func (pt PrimitiveType) ParseString(value string) (any, error) {
- switch pt {
- case PrimitiveTypeInteger:
- return strconv.Atoi(value)
- case PrimitiveTypeBoolean:
- return parseBoolean(value)
- case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID:
- return value, pt.validateStringValue(value)
- case PrimitiveTypeRoomAlias:
- plainErr := pt.validateStringValue(value)
- if plainErr == nil {
- return value, nil
- }
- parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
- if err != nil {
- return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err)
- } else if parsed.Sigil1 != '#' {
- return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1)
- }
- return parsed.RoomAlias(), nil
- case PrimitiveTypeRoomID, PrimitiveTypeEventID:
- parsed, err := parseRoomOrEventID(value)
- if err != nil {
- return nil, err
- } else if pt != parsed.Type {
- return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type)
- }
- return parsed, nil
- default:
- return nil, fmt.Errorf("cannot parse string for argument type %s", pt)
- }
-}
-
-func (ps *ParameterSchema) ParseString(value string) (any, error) {
- if ps == nil {
- return nil, fmt.Errorf("parameter schema is nil")
- }
- switch ps.SchemaType {
- case SchemaTypePrimitive:
- return ps.Type.ParseString(value)
- case SchemaTypeLiteral:
- switch typedValue := ps.Value.(type) {
- case string:
- if value == typedValue {
- return typedValue, nil
- } else {
- return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value)
- }
- case int, int64, float64, json.Number:
- expectedVal, _ := normalizeNumber(typedValue)
- intVal, err := strconv.Atoi(value)
- if err != nil {
- return nil, fmt.Errorf("failed to parse integer literal: %w", err)
- } else if intVal != expectedVal {
- return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal)
- }
- return intVal, nil
- case bool:
- boolVal, err := parseBoolean(value)
- if err != nil {
- return nil, fmt.Errorf("failed to parse boolean literal: %w", err)
- } else if boolVal != typedValue {
- return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal)
- }
- return boolVal, nil
- case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage:
- expectedVal, _ := NormalizeRoomIDValue(typedValue)
- parsed, err := parseRoomOrEventID(value)
- if err != nil {
- return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err)
- } else if !parsed.Equals(expectedVal) {
- return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed)
- }
- return parsed, nil
- default:
- return nil, fmt.Errorf("unsupported literal type %T", ps.Value)
- }
- case SchemaTypeUnion:
- var errs []error
- for _, variant := range ps.Variants {
- if parsed, err := variant.ParseString(value); err == nil {
- return parsed, nil
- } else {
- errs = append(errs, err)
- }
- }
- return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...))
- case SchemaTypeArray:
- return nil, fmt.Errorf("cannot parse string for array schema type")
- default:
- return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType)
- }
-}
diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go
deleted file mode 100644
index 1e0d1817..00000000
--- a/event/cmdschema/parse_test.go
+++ /dev/null
@@ -1,118 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package cmdschema
-
-import (
- "bytes"
- "encoding/json"
- "strings"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "go.mau.fi/util/exbytes"
- "go.mau.fi/util/exerrors"
-
- "maunium.net/go/mautrix/event/cmdschema/testdata"
-)
-
-type QuoteParseOutput struct {
- Parsed string
- Remaining string
- Quoted bool
-}
-
-func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error {
- var arr []any
- if err := json.Unmarshal(data, &arr); err != nil {
- return err
- }
- qpo.Parsed = arr[0].(string)
- qpo.Remaining = arr[1].(string)
- qpo.Quoted = arr[2].(bool)
- return nil
-}
-
-type QuoteParseTestData struct {
- Name string `json:"name"`
- Input string `json:"input"`
- Output QuoteParseOutput `json:"output"`
-}
-
-func loadFile[T any](name string) (into T) {
- quoteData := exerrors.Must(testdata.FS.ReadFile(name))
- exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into))
- return
-}
-
-func TestParseQuoted(t *testing.T) {
- qptd := loadFile[[]QuoteParseTestData]("parse_quote.json")
- for _, test := range qptd {
- t.Run(test.Name, func(t *testing.T) {
- parsed, remaining, quoted := parseQuoted(test.Input)
- assert.Equalf(t, test.Output, QuoteParseOutput{
- Parsed: parsed,
- Remaining: remaining,
- Quoted: quoted,
- }, "Failed with input `%s`", test.Input)
- // Note: can't just test that requoted == input, because some inputs
- // have unnecessary escapes which won't survive roundtripping
- t.Run("roundtrip", func(t *testing.T) {
- requoted := quoteString(parsed) + " " + remaining
- reparsed, newRemaining, _ := parseQuoted(requoted)
- assert.Equal(t, parsed, reparsed)
- assert.Equal(t, remaining, newRemaining)
- })
- })
- }
-}
-
-type CommandTestData struct {
- Spec *EventContent
- Tests []*CommandTestUnit
-}
-
-type CommandTestUnit struct {
- Name string `json:"name"`
- Input string `json:"input"`
- Broken string `json:"broken,omitempty"`
- Error bool `json:"error"`
- Output json.RawMessage `json:"output"`
-}
-
-func compactJSON(input json.RawMessage) json.RawMessage {
- var buf bytes.Buffer
- exerrors.PanicIfNotNil(json.Compact(&buf, input))
- return buf.Bytes()
-}
-
-func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) {
- for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) {
- t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) {
- ctd := loadFile[CommandTestData]("commands/" + cmd.Name())
- for _, test := range ctd.Tests {
- outputStr := exbytes.UnsafeString(compactJSON(test.Output))
- t.Run(test.Name, func(t *testing.T) {
- if test.Broken != "" {
- t.Skip(test.Broken)
- }
- output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input)
- if test.Error {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- }
- if outputStr == "null" {
- assert.Nil(t, output)
- } else {
- assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command)
- assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input)
- }
- })
- }
- })
- }
-}
diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go
deleted file mode 100644
index 98c421fc..00000000
--- a/event/cmdschema/roomid.go
+++ /dev/null
@@ -1,135 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package cmdschema
-
-import (
- "encoding/json"
- "fmt"
- "slices"
- "strings"
-
- "maunium.net/go/mautrix/id"
-)
-
-var ParameterSchemaJoinableRoom = Union(
- PrimitiveTypeRoomID.Schema(),
- PrimitiveTypeRoomAlias.Schema(),
-)
-
-type RoomIDValue struct {
- Type PrimitiveType `json:"type"`
- RoomID id.RoomID `json:"id"`
- Via []string `json:"via,omitempty"`
- EventID id.EventID `json:"event_id,omitempty"`
-}
-
-func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) {
- switch typedValue := input.(type) {
- case map[string]any, json.RawMessage:
- var raw json.RawMessage
- if raw, err = json.Marshal(input); err != nil {
- err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
- } else if err = json.Unmarshal(raw, &riv); err != nil {
- err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
- }
- case *RoomIDValue:
- riv = typedValue
- case RoomIDValue:
- riv = &typedValue
- default:
- err = fmt.Errorf("unsupported type %T for room or event ID", input)
- }
- return
-}
-
-func (riv *RoomIDValue) String() string {
- return riv.URI().String()
-}
-
-func (riv *RoomIDValue) URI() *id.MatrixURI {
- if riv == nil {
- return nil
- }
- switch riv.Type {
- case PrimitiveTypeRoomID:
- return riv.RoomID.URI(riv.Via...)
- case PrimitiveTypeEventID:
- return riv.RoomID.EventURI(riv.EventID, riv.Via...)
- default:
- return nil
- }
-}
-
-func (riv *RoomIDValue) Equals(other *RoomIDValue) bool {
- if riv == nil || other == nil {
- return riv == other
- }
- return riv.Type == other.Type &&
- riv.RoomID == other.RoomID &&
- riv.EventID == other.EventID &&
- slices.Equal(riv.Via, other.Via)
-}
-
-func (riv *RoomIDValue) Validate() error {
- if riv == nil {
- return fmt.Errorf("value is nil")
- }
- switch riv.Type {
- case PrimitiveTypeRoomID:
- if riv.EventID != "" {
- return fmt.Errorf("event ID must be empty for room ID type")
- }
- case PrimitiveTypeEventID:
- if !strings.HasPrefix(riv.EventID.String(), "$") {
- return fmt.Errorf("event ID not valid: %q", riv.EventID)
- }
- default:
- return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type)
- }
- for _, via := range riv.Via {
- if !id.ValidateServerName(via) {
- return fmt.Errorf("invalid server name %q in vias", via)
- }
- }
- sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID)
- if sigil != '!' {
- return fmt.Errorf("room ID does not start with !: %q", riv.RoomID)
- } else if localpart == "" && serverName == "" {
- return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID)
- } else if serverName != "" && !id.ValidateServerName(serverName) {
- return fmt.Errorf("invalid server name %q in room ID", serverName)
- }
- return nil
-}
-
-func (riv *RoomIDValue) IsValid() bool {
- return riv.Validate() == nil
-}
-
-type RoomIDOrString string
-
-func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error {
- if len(data) == 0 {
- return fmt.Errorf("empty data for room ID or string")
- }
- if data[0] == '"' {
- var str string
- if err := json.Unmarshal(data, &str); err != nil {
- return err
- }
- *ros = RoomIDOrString(str)
- return nil
- }
- var riv RoomIDValue
- if err := json.Unmarshal(data, &riv); err != nil {
- return err
- } else if err = riv.Validate(); err != nil {
- return err
- }
- *ros = RoomIDOrString(riv.String())
- return nil
-}
diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go
deleted file mode 100644
index c5c57c53..00000000
--- a/event/cmdschema/stringify.go
+++ /dev/null
@@ -1,122 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package cmdschema
-
-import (
- "encoding/json"
- "strconv"
- "strings"
-)
-
-var quoteEscaper = strings.NewReplacer(
- `"`, `\"`,
- `\`, `\\`,
-)
-
-const charsToQuote = ` \` + botArrayOpener + botArrayCloser
-
-func quoteString(val string) string {
- if val == "" {
- return `""`
- }
- val = quoteEscaper.Replace(val)
- if strings.ContainsAny(val, charsToQuote) {
- return `"` + val + `"`
- }
- return val
-}
-
-func (ec *EventContent) StringifyArgs(args any) string {
- var argMap map[string]any
- switch typedArgs := args.(type) {
- case json.RawMessage:
- err := json.Unmarshal(typedArgs, &argMap)
- if err != nil {
- return ""
- }
- case map[string]any:
- argMap = typedArgs
- default:
- if b, err := json.Marshal(args); err != nil {
- return ""
- } else if err = json.Unmarshal(b, &argMap); err != nil {
- return ""
- }
- }
- parts := make([]string, 0, len(ec.Parameters))
- for i, param := range ec.Parameters {
- isLast := i == len(ec.Parameters)-1
- val := argMap[param.Key]
- if val == nil {
- val = param.DefaultValue
- if val == nil && !param.Optional {
- val = param.Schema.GetDefaultValue()
- }
- }
- if val == nil {
- continue
- }
- var stringified string
- if param.Schema.SchemaType == SchemaTypeArray {
- stringified = arrayArgumentToString(val, isLast)
- } else {
- stringified = singleArgumentToString(val)
- }
- if stringified != "" {
- parts = append(parts, stringified)
- }
- }
- return strings.Join(parts, " ")
-}
-
-func arrayArgumentToString(val any, isLast bool) string {
- valArr, ok := val.([]any)
- if !ok {
- return ""
- }
- parts := make([]string, 0, len(valArr))
- for _, elem := range valArr {
- stringified := singleArgumentToString(elem)
- if stringified != "" {
- parts = append(parts, stringified)
- }
- }
- joinedParts := strings.Join(parts, " ")
- if isLast && len(parts) > 0 {
- return joinedParts
- }
- return botArrayOpener + joinedParts + botArrayCloser
-}
-
-func singleArgumentToString(val any) string {
- switch typedVal := val.(type) {
- case string:
- return quoteString(typedVal)
- case json.Number:
- return typedVal.String()
- case bool:
- return strconv.FormatBool(typedVal)
- case int:
- return strconv.Itoa(typedVal)
- case int64:
- return strconv.FormatInt(typedVal, 10)
- case float64:
- return strconv.FormatInt(int64(typedVal), 10)
- case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue:
- normalized, err := NormalizeRoomIDValue(typedVal)
- if err != nil {
- return ""
- }
- uri := normalized.URI()
- if uri == nil {
- return ""
- }
- return quoteString(uri.String())
- default:
- return ""
- }
-}
diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json
deleted file mode 100644
index e53382db..00000000
--- a/event/cmdschema/testdata/commands.schema.json
+++ /dev/null
@@ -1,281 +0,0 @@
-{
- "$schema": "https://json-schema.org/draft/2020-12/schema#",
- "$id": "commands.schema.json",
- "title": "ParseInput test cases",
- "description": "JSON schema for test case files containing command specifications and test cases",
- "type": "object",
- "required": [
- "spec",
- "tests"
- ],
- "additionalProperties": false,
- "properties": {
- "spec": {
- "title": "MSC4391 Command Description",
- "description": "JSON schema defining the structure of a bot command event content",
- "type": "object",
- "required": [
- "command"
- ],
- "additionalProperties": false,
- "properties": {
- "command": {
- "type": "string",
- "description": "The command name that triggers this bot command"
- },
- "aliases": {
- "type": "array",
- "description": "Alternative names/aliases for this command",
- "items": {
- "type": "string"
- }
- },
- "parameters": {
- "type": "array",
- "description": "List of parameters accepted by this command",
- "items": {
- "$ref": "#/$defs/Parameter"
- }
- },
- "description": {
- "$ref": "#/$defs/ExtensibleTextContainer",
- "description": "Human-readable description of the command"
- },
- "fi.mau.tail_parameter": {
- "type": "string",
- "description": "The key of the parameter that accepts remaining arguments as tail text"
- },
- "source": {
- "type": "string",
- "description": "The user ID of the bot that responds to this command"
- }
- }
- },
- "tests": {
- "type": "array",
- "description": "Array of test cases for the command",
- "items": {
- "type": "object",
- "description": "A single test case for command parsing",
- "required": [
- "name",
- "input"
- ],
- "additionalProperties": false,
- "properties": {
- "name": {
- "type": "string",
- "description": "The name of the test case"
- },
- "input": {
- "type": "string",
- "description": "The command input string to parse"
- },
- "output": {
- "description": "The expected parsed parameter values, or null if the parsing is expected to fail",
- "oneOf": [
- {
- "type": "object",
- "additionalProperties": true
- },
- {
- "type": "null"
- }
- ]
- },
- "error": {
- "type": "boolean",
- "description": "Whether parsing should result in an error. May still produce output.",
- "default": false
- }
- }
- }
- }
- },
- "$defs": {
- "ExtensibleTextContainer": {
- "type": "object",
- "description": "Container for text that can have multiple representations",
- "required": [
- "m.text"
- ],
- "properties": {
- "m.text": {
- "type": "array",
- "description": "Array of text representations in different formats",
- "items": {
- "$ref": "#/$defs/ExtensibleText"
- }
- }
- }
- },
- "ExtensibleText": {
- "type": "object",
- "description": "A text representation with a specific MIME type",
- "required": [
- "body"
- ],
- "properties": {
- "body": {
- "type": "string",
- "description": "The text content"
- },
- "mimetype": {
- "type": "string",
- "description": "The MIME type of the text (e.g., text/plain, text/html)",
- "default": "text/plain",
- "examples": [
- "text/plain",
- "text/html"
- ]
- }
- }
- },
- "Parameter": {
- "type": "object",
- "description": "A parameter definition for a command",
- "required": [
- "key",
- "schema"
- ],
- "additionalProperties": false,
- "properties": {
- "key": {
- "type": "string",
- "description": "The identifier for this parameter"
- },
- "schema": {
- "$ref": "#/$defs/ParameterSchema",
- "description": "The schema defining the type and structure of this parameter"
- },
- "optional": {
- "type": "boolean",
- "description": "Whether this parameter is optional",
- "default": false
- },
- "description": {
- "$ref": "#/$defs/ExtensibleTextContainer",
- "description": "Human-readable description of this parameter"
- },
- "fi.mau.default_value": {
- "description": "Default value for this parameter if not provided"
- }
- }
- },
- "ParameterSchema": {
- "type": "object",
- "description": "Schema definition for a parameter value",
- "required": [
- "schema_type"
- ],
- "additionalProperties": false,
- "properties": {
- "schema_type": {
- "type": "string",
- "enum": [
- "primitive",
- "array",
- "union",
- "literal"
- ],
- "description": "The type of schema"
- }
- },
- "allOf": [
- {
- "if": {
- "properties": {
- "schema_type": {
- "const": "primitive"
- }
- }
- },
- "then": {
- "required": [
- "type"
- ],
- "properties": {
- "type": {
- "type": "string",
- "enum": [
- "string",
- "integer",
- "boolean",
- "server_name",
- "user_id",
- "room_id",
- "room_alias",
- "event_id"
- ],
- "description": "The primitive type (only for schema_type: primitive)"
- }
- }
- }
- },
- {
- "if": {
- "properties": {
- "schema_type": {
- "const": "array"
- }
- }
- },
- "then": {
- "required": [
- "items"
- ],
- "properties": {
- "items": {
- "$ref": "#/$defs/ParameterSchema",
- "description": "The schema for array items (only for schema_type: array)"
- }
- }
- }
- },
- {
- "if": {
- "properties": {
- "schema_type": {
- "const": "union"
- }
- }
- },
- "then": {
- "required": [
- "variants"
- ],
- "properties": {
- "variants": {
- "type": "array",
- "description": "The possible variants (only for schema_type: union)",
- "items": {
- "$ref": "#/$defs/ParameterSchema"
- },
- "minItems": 1
- }
- }
- }
- },
- {
- "if": {
- "properties": {
- "schema_type": {
- "const": "literal"
- }
- }
- },
- "then": {
- "required": [
- "value"
- ],
- "properties": {
- "value": {
- "description": "The literal value (only for schema_type: literal)"
- }
- }
- }
- }
- ]
- }
- }
-}
diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json
deleted file mode 100644
index 6ce1f4da..00000000
--- a/event/cmdschema/testdata/commands/flags.json
+++ /dev/null
@@ -1,126 +0,0 @@
-{
- "$schema": "../commands.schema.json#",
- "spec": {
- "command": "flag",
- "source": "@testbot",
- "parameters": [
- {
- "key": "meow",
- "schema": {
- "schema_type": "primitive",
- "type": "string"
- }
- },
- {
- "key": "user",
- "schema": {
- "schema_type": "primitive",
- "type": "user_id"
- },
- "optional": true
- },
- {
- "key": "woof",
- "schema": {
- "schema_type": "primitive",
- "type": "boolean"
- },
- "optional": true,
- "fi.mau.default_value": false
- }
- ],
- "fi.mau.tail_parameter": "user"
- },
- "tests": [
- {
- "name": "no flags",
- "input": "/flag mrrp",
- "output": {
- "meow": "mrrp",
- "user": null
- }
- },
- {
- "name": "no flags, has tail",
- "input": "/flag mrrp @user:example.com",
- "output": {
- "meow": "mrrp",
- "user": "@user:example.com"
- }
- },
- {
- "name": "named flag at start",
- "input": "/flag --woof=yes mrrp @user:example.com",
- "output": {
- "meow": "mrrp",
- "user": "@user:example.com",
- "woof": true
- }
- },
- {
- "name": "boolean flag without value",
- "input": "/flag --woof mrrp @user:example.com",
- "output": {
- "meow": "mrrp",
- "user": "@user:example.com",
- "woof": true
- }
- },
- {
- "name": "user id flag without value",
- "input": "/flag --user --woof mrrp",
- "error": true,
- "output": {
- "meow": "mrrp",
- "user": null,
- "woof": true
- }
- },
- {
- "name": "named flag in the middle",
- "input": "/flag mrrp --woof=yes @user:example.com",
- "output": {
- "meow": "mrrp",
- "user": "@user:example.com",
- "woof": true
- }
- },
- {
- "name": "named flag in the middle with different value",
- "input": "/flag mrrp --woof=no @user:example.com",
- "output": {
- "meow": "mrrp",
- "user": "@user:example.com",
- "woof": false
- }
- },
- {
- "name": "all variables named",
- "input": "/flag --woof=no --meow=mrrp --user=@user:example.com",
- "output": {
- "meow": "mrrp",
- "user": "@user:example.com",
- "woof": false
- }
- },
- {
- "name": "all variables named with quotes",
- "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"",
- "output": {
- "meow": "meow meow mrrp",
- "user": "@user:example.com",
- "woof": true
- }
- },
- {
- "name": "invalid value for named parameter",
- "input": "/flag --user=meowings mrrp --woof",
- "error": true,
- "output": {
- "meow": "mrrp",
- "user": null,
- "woof": true
- }
- }
- ]
-}
diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json
deleted file mode 100644
index 1351c292..00000000
--- a/event/cmdschema/testdata/commands/room_id_or_alias.json
+++ /dev/null
@@ -1,85 +0,0 @@
-{
- "$schema": "../commands.schema.json#",
- "spec": {
- "command": "test room reference",
- "source": "@testbot",
- "parameters": [
- {
- "key": "room",
- "schema": {
- "schema_type": "union",
- "variants": [
- {
- "schema_type": "primitive",
- "type": "room_id"
- },
- {
- "schema_type": "primitive",
- "type": "room_alias"
- }
- ]
- }
- }
- ]
- },
- "tests": [
- {
- "name": "room alias",
- "input": "/test room reference #test:matrix.org",
- "output": {
- "room": "#test:matrix.org"
- }
- },
- {
- "name": "room id",
- "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
- "output": {
- "room": {
- "type": "room_id",
- "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
- }
- }
- },
- {
- "name": "room id matrix.to link",
- "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com",
- "output": {
- "room": {
- "type": "room_id",
- "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
- "via": [
- "example.com"
- ]
- }
- }
- },
- {
- "name": "room id matrix.to link with url encoding",
- "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net",
- "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly",
- "output": {
- "room": {
- "type": "room_id",
- "id": "!#test/room\nversion 11, with @🐈️:maunium.net",
- "via": [
- "maunium.net"
- ]
- }
- }
- },
- {
- "name": "room id matrix: URI",
- "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
- "output": {
- "room": {
- "type": "room_id",
- "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
- "via": [
- "maunium.net",
- "matrix.org"
- ]
- }
- }
- }
- ]
-}
diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json
deleted file mode 100644
index aa266054..00000000
--- a/event/cmdschema/testdata/commands/room_reference_list.json
+++ /dev/null
@@ -1,106 +0,0 @@
-{
- "$schema": "../commands.schema.json#",
- "spec": {
- "command": "test room reference",
- "source": "@testbot",
- "parameters": [
- {
- "key": "rooms",
- "schema": {
- "schema_type": "array",
- "items": {
- "schema_type": "union",
- "variants": [
- {
- "schema_type": "primitive",
- "type": "room_id"
- },
- {
- "schema_type": "primitive",
- "type": "room_alias"
- }
- ]
- }
- }
- }
- ]
- },
- "tests": [
- {
- "name": "room alias",
- "input": "/test room reference #test:matrix.org",
- "output": {
- "rooms": [
- "#test:matrix.org"
- ]
- }
- },
- {
- "name": "room id",
- "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
- "output": {
- "rooms": [
- {
- "type": "room_id",
- "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
- }
- ]
- }
- },
- {
- "name": "two room ids",
- "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org",
- "output": {
- "rooms": [
- {
- "type": "room_id",
- "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"
- },
- {
- "type": "room_id",
- "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
- }
- ]
- }
- },
- {
- "name": "room id matrix: URI",
- "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
- "output": {
- "rooms": [
- {
- "type": "room_id",
- "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
- "via": [
- "maunium.net",
- "matrix.org"
- ]
- }
- ]
- }
- },
- {
- "name": "room id matrix: URI and matrix.to URL",
- "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
- "output": {
- "rooms": [
- {
- "type": "room_id",
- "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
- "via": [
- "example.com"
- ]
- },
- {
- "type": "room_id",
- "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
- "via": [
- "maunium.net",
- "matrix.org"
- ]
- }
- ]
- }
- }
- ]
-}
diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json
deleted file mode 100644
index 94667323..00000000
--- a/event/cmdschema/testdata/commands/simple.json
+++ /dev/null
@@ -1,46 +0,0 @@
-{
- "$schema": "../commands.schema.json#",
- "spec": {
- "command": "test simple",
- "source": "@testbot",
- "parameters": [
- {
- "key": "meow",
- "schema": {
- "schema_type": "primitive",
- "type": "string"
- }
- }
- ]
- },
- "tests": [
- {
- "name": "success",
- "input": "/test simple mrrp",
- "output": {
- "meow": "mrrp"
- }
- },
- {
- "name": "directed success",
- "input": "/test simple@testbot mrrp",
- "output": {
- "meow": "mrrp"
- }
- },
- {
- "name": "missing parameter",
- "input": "/test simple",
- "error": true,
- "output": {
- "meow": ""
- }
- },
- {
- "name": "directed at another bot",
- "input": "/test simple@anotherbot mrrp",
- "error": false,
- "output": null
- }
- ]
-}
diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json
deleted file mode 100644
index 9782f8ec..00000000
--- a/event/cmdschema/testdata/commands/tail.json
+++ /dev/null
@@ -1,60 +0,0 @@
-{
- "$schema": "../commands.schema.json#",
- "spec": {
- "command": "tail",
- "source": "@testbot",
- "parameters": [
- {
- "key": "meow",
- "schema": {
- "schema_type": "primitive",
- "type": "string"
- }
- },
- {
- "key": "reason",
- "schema": {
- "schema_type": "primitive",
- "type": "string"
- },
- "optional": true
- },
- {
- "key": "woof",
- "schema": {
- "schema_type": "primitive",
- "type": "boolean"
- },
- "optional": true
- }
- ],
- "fi.mau.tail_parameter": "reason"
- },
- "tests": [
- {
- "name": "no tail or flag",
- "input": "/tail mrrp",
- "output": {
- "meow": "mrrp",
- "reason": ""
- }
- },
- {
- "name": "tail, no flag",
- "input": "/tail mrrp meow meow",
- "output": {
- "meow": "mrrp",
- "reason": "meow meow"
- }
- },
- {
- "name": "flag before tail",
- "input": "/tail mrrp --woof meow meow",
- "output": {
- "meow": "mrrp",
- "reason": "meow meow",
- "woof": true
- }
- }
- ]
-}
diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go
deleted file mode 100644
index eceea3d2..00000000
--- a/event/cmdschema/testdata/data.go
+++ /dev/null
@@ -1,14 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package testdata
-
-import (
- "embed"
-)
-
-//go:embed *
-var FS embed.FS
diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json
deleted file mode 100644
index 8f52b7f5..00000000
--- a/event/cmdschema/testdata/parse_quote.json
+++ /dev/null
@@ -1,30 +0,0 @@
-[
- {"name": "empty string", "input": "", "output": ["", "", false]},
- {"name": "single word", "input": "meow", "output": ["meow", "", false]},
- {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]},
- {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
- {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
- {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]},
- {"name": "only spaces", "input": " ", "output": ["", "", false]},
- {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]},
- {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]},
- {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
- {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]},
- {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]},
- {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]},
- {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]},
- {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]},
- {"name": "just opening quote", "input": "\"", "output": ["", "", true]},
- {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]},
- {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]},
- {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]},
- {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]},
- {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]},
- {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]},
- {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]},
- {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
- {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]},
- {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]},
- {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]},
- {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]}
-]
diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json
deleted file mode 100644
index 9f249116..00000000
--- a/event/cmdschema/testdata/parse_quote.schema.json
+++ /dev/null
@@ -1,46 +0,0 @@
-{
- "$schema": "https://json-schema.org/draft/2020-12/schema#",
- "$id": "parse_quote.schema.json",
- "title": "parseQuote test cases",
- "description": "Test cases for the parseQuoted function",
- "type": "array",
- "items": {
- "type": "object",
- "required": [
- "name",
- "input",
- "output"
- ],
- "properties": {
- "name": {
- "type": "string",
- "description": "Name of the test case"
- },
- "input": {
- "type": "string",
- "description": "Input string to be parsed"
- },
- "output": {
- "type": "array",
- "description": "Expected output of parsing: [first word, remaining text, was quoted]",
- "minItems": 3,
- "maxItems": 3,
- "prefixItems": [
- {
- "type": "string",
- "description": "First parsed word"
- },
- {
- "type": "string",
- "description": "Remaining text after the first word"
- },
- {
- "type": "boolean",
- "description": "Whether the first word was quoted"
- }
- ]
- }
- },
- "additionalProperties": false
- }
-}
diff --git a/event/content.go b/event/content.go
index 814aeec4..b56e35f2 100644
--- a/event/content.go
+++ b/event/content.go
@@ -18,7 +18,6 @@ import (
// This is used by Content.ParseRaw() for creating the correct type of struct.
var TypeMap = map[Type]reflect.Type{
StateMember: reflect.TypeOf(MemberEventContent{}),
- StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}),
StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}),
StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}),
StateRoomName: reflect.TypeOf(RoomNameEventContent{}),
@@ -39,9 +38,7 @@ var TypeMap = map[Type]reflect.Type{
StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}),
StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}),
StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}),
-
- StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
- StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+ StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}),
StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}),
StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}),
@@ -52,7 +49,6 @@ var TypeMap = map[Type]reflect.Type{
StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}),
StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}),
- StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}),
EventMessage: reflect.TypeOf(MessageEventContent{}),
EventSticker: reflect.TypeOf(MessageEventContent{}),
@@ -63,11 +59,8 @@ var TypeMap = map[Type]reflect.Type{
EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}),
EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}),
- BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
- BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
- BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}),
- BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}),
- BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}),
+ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
AccountDataRoomTags: reflect.TypeOf(TagEventContent{}),
AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}),
@@ -76,11 +69,9 @@ var TypeMap = map[Type]reflect.Type{
AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}),
AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}),
- EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
- EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
- EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
- EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}),
- BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}),
+ EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
+ EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
+ EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}),
InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
diff --git a/event/delayed.go b/event/delayed.go
deleted file mode 100644
index fefb62af..00000000
--- a/event/delayed.go
+++ /dev/null
@@ -1,70 +0,0 @@
-package event
-
-import (
- "encoding/json"
-
- "go.mau.fi/util/jsontime"
-
- "maunium.net/go/mautrix/id"
-)
-
-type ScheduledDelayedEvent struct {
- DelayID id.DelayID `json:"delay_id"`
- RoomID id.RoomID `json:"room_id"`
- Type Type `json:"type"`
- StateKey *string `json:"state_key,omitempty"`
- Delay int64 `json:"delay"`
- RunningSince jsontime.UnixMilli `json:"running_since"`
- Content Content `json:"content"`
-}
-
-func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) {
- evt := &Event{
- ID: eventID,
- RoomID: e.RoomID,
- Type: e.Type,
- StateKey: e.StateKey,
- Content: e.Content,
- Timestamp: ts.UnixMilli(),
- }
- return evt, evt.Content.ParseRaw(evt.Type)
-}
-
-type FinalisedDelayedEvent struct {
- DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"`
- Outcome DelayOutcome `json:"outcome"`
- Reason DelayReason `json:"reason"`
- Error json.RawMessage `json:"error,omitempty"`
- EventID id.EventID `json:"event_id,omitempty"`
- Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
-}
-
-type DelayStatus string
-
-var (
- DelayStatusScheduled DelayStatus = "scheduled"
- DelayStatusFinalised DelayStatus = "finalised"
-)
-
-type DelayAction string
-
-var (
- DelayActionSend DelayAction = "send"
- DelayActionCancel DelayAction = "cancel"
- DelayActionRestart DelayAction = "restart"
-)
-
-type DelayOutcome string
-
-var (
- DelayOutcomeSend DelayOutcome = "send"
- DelayOutcomeCancel DelayOutcome = "cancel"
-)
-
-type DelayReason string
-
-var (
- DelayReasonAction DelayReason = "action"
- DelayReasonError DelayReason = "error"
- DelayReasonDelay DelayReason = "delay"
-)
diff --git a/event/encryption.go b/event/encryption.go
index c60cb91a..cf9c2814 100644
--- a/event/encryption.go
+++ b/event/encryption.go
@@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
case id.AlgorithmMegolmV1:
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
- return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString)
+ return id.InputNotJSONString
}
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
}
@@ -132,9 +132,8 @@ type RoomKeyRequestEventContent struct {
type RequestedKeyInfo struct {
Algorithm id.Algorithm `json:"algorithm"`
RoomID id.RoomID `json:"room_id"`
- SessionID id.SessionID `json:"session_id"`
- // Deprecated: Matrix v1.3
SenderKey id.SenderKey `json:"sender_key"`
+ SessionID id.SessionID `json:"session_id"`
}
type RoomKeyWithheldCode string
diff --git a/event/events.go b/event/events.go
index 72c1e161..a763cc31 100644
--- a/event/events.go
+++ b/event/events.go
@@ -130,29 +130,36 @@ func (evt *Event) GetStateKey() string {
return ""
}
+type StrippedState struct {
+ Content Content `json:"content"`
+ Type Type `json:"type"`
+ StateKey string `json:"state_key"`
+ Sender id.UserID `json:"sender"`
+}
+
type Unsigned struct {
- PrevContent *Content `json:"prev_content,omitempty"`
- PrevSender id.UserID `json:"prev_sender,omitempty"`
- Membership Membership `json:"membership,omitempty"`
- ReplacesState id.EventID `json:"replaces_state,omitempty"`
- Age int64 `json:"age,omitempty"`
- TransactionID string `json:"transaction_id,omitempty"`
- Relations *Relations `json:"m.relations,omitempty"`
- RedactedBecause *Event `json:"redacted_because,omitempty"`
- InviteRoomState []*Event `json:"invite_room_state,omitempty"`
+ PrevContent *Content `json:"prev_content,omitempty"`
+ PrevSender id.UserID `json:"prev_sender,omitempty"`
+ Membership Membership `json:"membership,omitempty"`
+ ReplacesState id.EventID `json:"replaces_state,omitempty"`
+ Age int64 `json:"age,omitempty"`
+ TransactionID string `json:"transaction_id,omitempty"`
+ Relations *Relations `json:"m.relations,omitempty"`
+ RedactedBecause *Event `json:"redacted_because,omitempty"`
+ InviteRoomState []StrippedState `json:"invite_room_state,omitempty"`
BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"`
BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"`
BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"`
BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"`
- ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"`
- ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"`
+ MauSoftFailed bool `json:"fi.mau.soft_failed,omitempty"`
+ MauRejectionReason string `json:"fi.mau.rejection_reason,omitempty"`
}
func (us *Unsigned) IsEmpty() bool {
return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" &&
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil &&
us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() &&
- !us.ElementSoftFailed
+ !us.MauSoftFailed && us.MauRejectionReason == ""
}
diff --git a/event/member.go b/event/member.go
index 9956a36b..02b7cae9 100644
--- a/event/member.go
+++ b/event/member.go
@@ -7,6 +7,8 @@
package event
import (
+ "encoding/json"
+
"maunium.net/go/mautrix/id"
)
@@ -33,37 +35,22 @@ const (
// MemberEventContent represents the content of a m.room.member state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroommember
type MemberEventContent struct {
- Membership Membership `json:"membership"`
- AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
- Displayname string `json:"displayname,omitempty"`
- IsDirect bool `json:"is_direct,omitempty"`
- ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
- Reason string `json:"reason,omitempty"`
- JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"`
- MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
+ Membership Membership `json:"membership"`
+ AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
+ Displayname string `json:"displayname,omitempty"`
+ IsDirect bool `json:"is_direct,omitempty"`
+ ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"`
MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"`
}
-type SignedThirdPartyInvite struct {
- Token string `json:"token"`
- Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
- MXID string `json:"mxid"`
-}
-
type ThirdPartyInvite struct {
- DisplayName string `json:"display_name"`
- Signed SignedThirdPartyInvite `json:"signed"`
-}
-
-type ThirdPartyInviteEventContent struct {
- DisplayName string `json:"display_name"`
- KeyValidityURL string `json:"key_validity_url"`
- PublicKey id.Ed25519 `json:"public_key"`
- PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"`
-}
-
-type ThirdPartyInviteKey struct {
- KeyValidityURL string `json:"key_validity_url,omitempty"`
- PublicKey id.Ed25519 `json:"public_key"`
+ DisplayName string `json:"display_name"`
+ Signed struct {
+ Token string `json:"token"`
+ Signatures json.RawMessage `json:"signatures"`
+ MXID string `json:"mxid"`
+ }
}
diff --git a/event/message.go b/event/message.go
index 3fb3dc82..51403889 100644
--- a/event/message.go
+++ b/event/message.go
@@ -135,16 +135,11 @@ type MessageEventContent struct {
BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"`
- BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"`
BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"`
- BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"`
-
MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"`
MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"`
-
- MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"`
}
func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType {
@@ -276,25 +271,6 @@ func (m *Mentions) Has(userID id.UserID) bool {
return m != nil && slices.Contains(m.UserIDs, userID)
}
-func (m *Mentions) Merge(other *Mentions) *Mentions {
- if m == nil {
- return other
- } else if other == nil {
- return m
- }
- return &Mentions{
- UserIDs: slices.Concat(m.UserIDs, other.UserIDs),
- Room: m.Room || other.Room,
- }
-}
-
-type MSC4391BotCommandInputCustom[T any] struct {
- Command string `json:"command"`
- Arguments T `json:"arguments,omitempty"`
-}
-
-type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage]
-
type EncryptedFileInfo struct {
attachment.EncryptedFile
URL id.ContentURIString `json:"url"`
@@ -309,8 +285,7 @@ type FileInfo struct {
Blurhash string
AnoaBlurhash string
- MauGIF bool
- IsAnimated bool
+ MauGIF bool
Width int
Height int
@@ -327,8 +302,7 @@ type serializableFileInfo struct {
Blurhash string `json:"blurhash,omitempty"`
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
- MauGIF bool `json:"fi.mau.gif,omitempty"`
- IsAnimated bool `json:"is_animated,omitempty"`
+ MauGIF bool `json:"fi.mau.gif,omitempty"`
Width json.Number `json:"w,omitempty"`
Height json.Number `json:"h,omitempty"`
@@ -346,8 +320,7 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI
ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo),
ThumbnailFile: fileInfo.ThumbnailFile,
- MauGIF: fileInfo.MauGIF,
- IsAnimated: fileInfo.IsAnimated,
+ MauGIF: fileInfo.MauGIF,
Blurhash: fileInfo.Blurhash,
AnoaBlurhash: fileInfo.AnoaBlurhash,
@@ -378,7 +351,6 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) {
ThumbnailURL: sfi.ThumbnailURL,
ThumbnailFile: sfi.ThumbnailFile,
MauGIF: sfi.MauGIF,
- IsAnimated: sfi.IsAnimated,
Blurhash: sfi.Blurhash,
AnoaBlurhash: sfi.AnoaBlurhash,
}
diff --git a/event/message_test.go b/event/message_test.go
index c721df35..562a6622 100644
--- a/event/message_test.go
+++ b/event/message_test.go
@@ -33,7 +33,7 @@ const invalidMessageEvent = `{
func TestMessageEventContent__ParseInvalid(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(invalidMessageEvent), &evt)
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) {
assert.Equal(t, id.RoomID("!bar"), evt.RoomID)
err = evt.Content.ParseRaw(evt.Type)
- assert.Error(t, err)
+ assert.NotNil(t, err)
}
const messageEvent = `{
@@ -68,7 +68,7 @@ const messageEvent = `{
func TestMessageEventContent__ParseEdit(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(messageEvent), &evt)
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -110,7 +110,7 @@ const imageMessageEvent = `{
func TestMessageEventContent__ParseMedia(t *testing.T) {
var evt *event.Event
err := json.Unmarshal([]byte(imageMessageEvent), &evt)
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
assert.Equal(t, event.EventMessage, evt.Type)
@@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) {
content := evt.Content.Parsed.(*event.MessageEventContent)
assert.Equal(t, event.MsgImage, content.MsgType)
parsedURL, err := content.URL.Parse()
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL)
assert.Nil(t, content.NewContent)
assert.Equal(t, "image/png", content.GetInfo().MimeType)
@@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}`
func TestMessageEventContent__Marshal(t *testing.T) {
data, err := json.Marshal(parsedMessage)
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, expectedMarshalResult, string(data))
}
@@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun
func TestMessageEventContent__Marshal_Custom(t *testing.T) {
data, err := json.Marshal(customParsedMessage)
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, expectedCustomMarshalResult, string(data))
}
diff --git a/event/poll.go b/event/poll.go
index 9082f65e..47131a8f 100644
--- a/event/poll.go
+++ b/event/poll.go
@@ -35,7 +35,7 @@ type MSC1767Message struct {
}
type PollStartEventContent struct {
- RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
+ RelatesTo *RelatesTo `json:"m.relates_to"`
Mentions *Mentions `json:"m.mentions,omitempty"`
PollStart struct {
Kind string `json:"kind"`
diff --git a/event/powerlevels.go b/event/powerlevels.go
index 668eb6d3..2f4d4573 100644
--- a/event/powerlevels.go
+++ b/event/powerlevels.go
@@ -7,8 +7,6 @@
package event
import (
- "math"
- "slices"
"sync"
"go.mau.fi/util/ptr"
@@ -28,9 +26,6 @@ type PowerLevelsEventContent struct {
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
- beeperEphemeralLock sync.RWMutex
- BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"`
-
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
StateDefaultPtr *int `json:"state_default,omitempty"`
@@ -39,12 +34,6 @@ type PowerLevelsEventContent struct {
KickPtr *int `json:"kick,omitempty"`
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
-
- BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"`
-
- // This is not a part of power levels, it's added by mautrix-go internally in certain places
- // in order to detect creator power accurately.
- CreateEvent *Event `json:"-"`
}
func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
@@ -56,7 +45,6 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
UsersDefault: pl.UsersDefault,
Events: maps.Clone(pl.Events),
EventsDefault: pl.EventsDefault,
- BeeperEphemeral: maps.Clone(pl.BeeperEphemeral),
StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr),
Notifications: pl.Notifications.Clone(),
@@ -65,10 +53,6 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
KickPtr: ptr.Clone(pl.KickPtr),
BanPtr: ptr.Clone(pl.BanPtr),
RedactPtr: ptr.Clone(pl.RedactPtr),
-
- BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr),
-
- CreateEvent: pl.CreateEvent,
}
}
@@ -127,17 +111,7 @@ func (pl *PowerLevelsEventContent) StateDefault() int {
return 50
}
-func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int {
- if pl.BeeperEphemeralDefaultPtr != nil {
- return *pl.BeeperEphemeralDefaultPtr
- }
- return pl.EventsDefault
-}
-
func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
- if pl.isCreator(userID) {
- return math.MaxInt
- }
pl.usersLock.RLock()
defer pl.usersLock.RUnlock()
level, ok := pl.Users[userID]
@@ -147,19 +121,9 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
return level
}
-const maxPL = 1<<53 - 1
-
func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) {
pl.usersLock.Lock()
defer pl.usersLock.Unlock()
- if pl.isCreator(userID) {
- return
- }
- if level == math.MaxInt && maxPL < math.MaxInt {
- // Hack to avoid breaking on 32-bit systems (they're only slightly supported)
- x := int64(maxPL)
- level = int(x)
- }
if level == pl.UsersDefault {
delete(pl.Users, userID)
} else {
@@ -174,24 +138,9 @@ func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int)
return pl.EnsureUserLevelAs("", target, level)
}
-func (pl *PowerLevelsEventContent) createContent() *CreateEventContent {
- if pl.CreateEvent == nil {
- return &CreateEventContent{}
- }
- return pl.CreateEvent.Content.AsCreate()
-}
-
-func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool {
- cc := pl.createContent()
- return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID))
-}
-
func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool {
- if pl.isCreator(target) {
- return false
- }
existingLevel := pl.GetUserLevel(target)
- if actor != "" && !pl.isCreator(actor) {
+ if actor != "" {
actorLevel := pl.GetUserLevel(actor)
if actorLevel <= existingLevel || actorLevel < level {
return false
@@ -217,29 +166,6 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int {
return level
}
-func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int {
- pl.beeperEphemeralLock.RLock()
- defer pl.beeperEphemeralLock.RUnlock()
- level, ok := pl.BeeperEphemeral[eventType.String()]
- if !ok {
- return pl.BeeperEphemeralDefault()
- }
- return level
-}
-
-func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) {
- pl.beeperEphemeralLock.Lock()
- defer pl.beeperEphemeralLock.Unlock()
- if level == pl.BeeperEphemeralDefault() {
- delete(pl.BeeperEphemeral, eventType.String())
- } else {
- if pl.BeeperEphemeral == nil {
- pl.BeeperEphemeral = make(map[string]int)
- }
- pl.BeeperEphemeral[eventType.String()] = level
- }
-}
-
func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) {
pl.eventsLock.Lock()
defer pl.eventsLock.Unlock()
@@ -259,7 +185,7 @@ func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) b
func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool {
existingLevel := pl.GetEventLevel(eventType)
- if actor != "" && !pl.isCreator(actor) {
+ if actor != "" {
actorLevel := pl.GetUserLevel(actor)
if existingLevel > actorLevel || level > actorLevel {
return false
diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go
deleted file mode 100644
index f5861583..00000000
--- a/event/powerlevels_ephemeral_test.go
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package event_test
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- "maunium.net/go/mautrix/event"
-)
-
-func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) {
- pl := &event.PowerLevelsEventContent{
- EventsDefault: 45,
- }
-
- assert.Equal(t, 45, pl.BeeperEphemeralDefault())
-
- override := 60
- pl.BeeperEphemeralDefaultPtr = &override
- assert.Equal(t, 60, pl.BeeperEphemeralDefault())
-}
-
-func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) {
- pl := &event.PowerLevelsEventContent{
- EventsDefault: 25,
- }
- evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
-
- assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType))
-
- pl.SetBeeperEphemeralLevel(evtType, 50)
- assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType))
- require.NotNil(t, pl.BeeperEphemeral)
- assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()])
-
- pl.SetBeeperEphemeralLevel(evtType, 25)
- _, exists := pl.BeeperEphemeral[evtType.String()]
- assert.False(t, exists)
-}
-
-func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) {
- override := 70
- pl := &event.PowerLevelsEventContent{
- EventsDefault: 35,
- BeeperEphemeral: map[string]int{"com.example.ephemeral": 90},
- BeeperEphemeralDefaultPtr: &override,
- }
-
- cloned := pl.Clone()
- require.NotNil(t, cloned)
- require.NotNil(t, cloned.BeeperEphemeralDefaultPtr)
- assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr)
- assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"])
-
- cloned.BeeperEphemeral["com.example.ephemeral"] = 99
- *cloned.BeeperEphemeralDefaultPtr = 71
-
- assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"])
- assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr)
-}
diff --git a/event/reply.go b/event/reply.go
index 5f55bb80..9ae1c110 100644
--- a/event/reply.go
+++ b/event/reply.go
@@ -32,13 +32,12 @@ func TrimReplyFallbackText(text string) string {
}
func (content *MessageEventContent) RemoveReplyFallback() {
- if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML {
- origHTML := content.FormattedBody
- content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
- if content.FormattedBody != origHTML {
- content.Body = TrimReplyFallbackText(content.Body)
- content.replyFallbackRemoved = true
+ if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
+ if content.Format == FormatHTML {
+ content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
}
+ content.Body = TrimReplyFallbackText(content.Body)
+ content.replyFallbackRemoved = true
}
}
diff --git a/event/state.go b/event/state.go
index ace170a5..028691e1 100644
--- a/event/state.go
+++ b/event/state.go
@@ -8,10 +8,6 @@ package event
import (
"encoding/base64"
- "encoding/json"
- "slices"
-
- "go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/id"
)
@@ -56,40 +52,10 @@ type TopicEventContent struct {
// m.room.topic state event as described in [MSC3765].
//
// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765
-type ExtensibleTopic = ExtensibleTextContainer
-
-type ExtensibleTextContainer struct {
+type ExtensibleTopic struct {
Text []ExtensibleText `json:"m.text"`
}
-func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool {
- if c == nil || description == nil {
- return c == description
- }
- return slices.Equal(c.Text, description.Text)
-}
-
-func MakeExtensibleText(text string) *ExtensibleTextContainer {
- return &ExtensibleTextContainer{
- Text: []ExtensibleText{{
- Body: text,
- MimeType: "text/plain",
- }},
- }
-}
-
-func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer {
- return &ExtensibleTextContainer{
- Text: []ExtensibleText{{
- Body: plaintext,
- MimeType: "text/plain",
- }, {
- Body: html,
- MimeType: "text/html",
- }},
- }
-}
-
// ExtensibleText represents the contents of an m.text field.
type ExtensibleText struct {
MimeType string `json:"mimetype,omitempty"`
@@ -103,66 +69,39 @@ type TombstoneEventContent struct {
ReplacementRoom id.RoomID `json:"replacement_room"`
}
-func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID {
- if tec == nil {
- return ""
- }
- return tec.ReplacementRoom
-}
-
type Predecessor struct {
RoomID id.RoomID `json:"room_id"`
EventID id.EventID `json:"event_id"`
}
-// Deprecated: use id.RoomVersion instead
-type RoomVersion = id.RoomVersion
+type RoomVersion string
-// Deprecated: use id.RoomVX constants instead
const (
- RoomV1 = id.RoomV1
- RoomV2 = id.RoomV2
- RoomV3 = id.RoomV3
- RoomV4 = id.RoomV4
- RoomV5 = id.RoomV5
- RoomV6 = id.RoomV6
- RoomV7 = id.RoomV7
- RoomV8 = id.RoomV8
- RoomV9 = id.RoomV9
- RoomV10 = id.RoomV10
- RoomV11 = id.RoomV11
- RoomV12 = id.RoomV12
+ RoomV1 RoomVersion = "1"
+ RoomV2 RoomVersion = "2"
+ RoomV3 RoomVersion = "3"
+ RoomV4 RoomVersion = "4"
+ RoomV5 RoomVersion = "5"
+ RoomV6 RoomVersion = "6"
+ RoomV7 RoomVersion = "7"
+ RoomV8 RoomVersion = "8"
+ RoomV9 RoomVersion = "9"
+ RoomV10 RoomVersion = "10"
+ RoomV11 RoomVersion = "11"
)
// CreateEventContent represents the content of a m.room.create state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomcreate
type CreateEventContent struct {
- Type RoomType `json:"type,omitempty"`
- Federate *bool `json:"m.federate,omitempty"`
- RoomVersion id.RoomVersion `json:"room_version,omitempty"`
- Predecessor *Predecessor `json:"predecessor,omitempty"`
-
- // Room v12+ only
- AdditionalCreators []id.UserID `json:"additional_creators,omitempty"`
+ Type RoomType `json:"type,omitempty"`
+ Federate *bool `json:"m.federate,omitempty"`
+ RoomVersion RoomVersion `json:"room_version,omitempty"`
+ Predecessor *Predecessor `json:"predecessor,omitempty"`
// Deprecated: use the event sender instead
Creator id.UserID `json:"creator,omitempty"`
}
-func (cec *CreateEventContent) GetPredecessor() (p Predecessor) {
- if cec != nil && cec.Predecessor != nil {
- p = *cec.Predecessor
- }
- return
-}
-
-func (cec *CreateEventContent) SupportsCreatorPower() bool {
- if cec == nil {
- return false
- }
- return cec.RoomVersion.PrivilegedRoomCreators()
-}
-
// JoinRule specifies how open a room is to new members.
// https://spec.matrix.org/v1.2/client-server-api/#mroomjoin_rules
type JoinRule string
@@ -238,8 +177,7 @@ type BridgeInfoSection struct {
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
ExternalURL string `json:"external_url,omitempty"`
- Receiver string `json:"fi.mau.receiver,omitempty"`
- MessageRequest bool `json:"com.beeper.message_request,omitempty"`
+ Receiver string `json:"fi.mau.receiver,omitempty"`
}
// BridgeEventContent represents the content of a m.bridge state event.
@@ -253,32 +191,6 @@ type BridgeEventContent struct {
BeeperRoomType string `json:"com.beeper.room_type,omitempty"`
BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"`
-
- TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"`
- TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"`
-}
-
-// DisappearingType represents the type of a disappearing message timer.
-type DisappearingType string
-
-const (
- DisappearingTypeNone DisappearingType = ""
- DisappearingTypeAfterRead DisappearingType = "after_read"
- DisappearingTypeAfterSend DisappearingType = "after_send"
-)
-
-type BeeperDisappearingTimer struct {
- Type DisappearingType `json:"type"`
- Timer jsontime.Milliseconds `json:"timer"`
-}
-
-type marshalableBeeperDisappearingTimer BeeperDisappearingTimer
-
-func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) {
- if bdt == nil || bdt.Type == DisappearingTypeNone {
- return []byte("{}"), nil
- }
- return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt))
}
type SpaceChildEventContent struct {
@@ -332,26 +244,12 @@ func (mpc *ModPolicyContent) EntityOrHash() string {
return mpc.Entity
}
+// Deprecated: MSC2716 has been abandoned
+type InsertionMarkerContent struct {
+ InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"`
+ Timestamp int64 `json:"com.beeper.timestamp,omitempty"`
+}
+
type ElementFunctionalMembersContent struct {
ServiceMembers []id.UserID `json:"service_members"`
}
-
-func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool {
- if slices.Contains(efmc.ServiceMembers, mxid) {
- return false
- }
- efmc.ServiceMembers = append(efmc.ServiceMembers, mxid)
- return true
-}
-
-type PolicyServerPublicKeys struct {
- Ed25519 id.Ed25519 `json:"ed25519,omitempty"`
-}
-
-type RoomPolicyEventContent struct {
- Via string `json:"via,omitempty"`
- PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"`
-
- // Deprecated, only for legacy use
- PublicKey id.Ed25519 `json:"public_key,omitempty"`
-}
diff --git a/event/type.go b/event/type.go
index 80b86728..591d598d 100644
--- a/event/type.go
+++ b/event/type.go
@@ -108,14 +108,13 @@ func (et *Type) IsCustom() bool {
func (et *Type) GuessClass() TypeClass {
switch et.Type {
- case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type,
+ case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type,
StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type,
StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type,
StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type,
- StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type,
- StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type:
+ StateInsertionMarker.Type, StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type:
return StateEventType
- case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type:
+ case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type:
return EphemeralEventType
case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type,
AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type,
@@ -128,7 +127,7 @@ func (et *Type) GuessClass() TypeClass {
InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type,
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type,
- EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type:
+ BeeperTranscription.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
@@ -178,7 +177,6 @@ var (
StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType}
StateGuestAccess = Type{"m.room.guest_access", StateEventType}
StateMember = Type{"m.room.member", StateEventType}
- StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType}
StatePowerLevels = Type{"m.room.power_levels", StateEventType}
StateRoomName = Type{"m.room.name", StateEventType}
StateTopic = Type{"m.room.topic", StateEventType}
@@ -195,9 +193,6 @@ var (
StateSpaceChild = Type{"m.space.child", StateEventType}
StateSpaceParent = Type{"m.space.parent", StateEventType}
- StateRoomPolicy = Type{"m.room.policy", StateEventType}
- StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType}
-
StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType}
StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType}
StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType}
@@ -205,10 +200,11 @@ var (
StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType}
StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType}
+ // Deprecated: MSC2716 has been abandoned
+ StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType}
+
StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType}
StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType}
- StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType}
- StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType}
)
// Message events
@@ -237,24 +233,18 @@ var (
CallNegotiate = Type{"m.call.negotiate", MessageEventType}
CallHangup = Type{"m.call.hangup", MessageEventType}
- BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
- BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
- BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType}
- BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType}
- BeeperSendState = Type{"com.beeper.send_state", MessageEventType}
+ BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType}
EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType}
- EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType}
)
// Ephemeral events
var (
- EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
- EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
- EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
- EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType}
- BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType}
+ EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
+ EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
+ EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
)
// Account data events
diff --git a/example/main.go b/example/main.go
index 2bf4bef3..d8006d46 100644
--- a/example/main.go
+++ b/example/main.go
@@ -143,7 +143,7 @@ func main() {
if err != nil {
log.Error().Err(err).Msg("Failed to send event")
} else {
- log.Info().Stringer("event_id", resp.EventID).Msg("Event sent")
+ log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent")
}
}
cancelSync()
diff --git a/federation/client.go b/federation/client.go
index 183fb5d1..7c460d44 100644
--- a/federation/client.go
+++ b/federation/client.go
@@ -21,7 +21,6 @@ import (
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
- "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -30,8 +29,6 @@ type Client struct {
ServerName string
UserAgent string
Key *SigningKey
-
- ResponseSizeLimit int64
}
func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client {
@@ -39,16 +36,10 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien
HTTP: &http.Client{
Transport: NewServerResolvingTransport(cache),
Timeout: 120 * time.Second,
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- // Federation requests do not allow redirects.
- return http.ErrUseLastResponse
- },
},
UserAgent: mautrix.DefaultUserAgent,
ServerName: serverName,
Key: key,
-
- ResponseSizeLimit: mautrix.DefaultResponseSizeLimit,
}
}
@@ -89,7 +80,7 @@ type RespSendTransaction struct {
}
func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) {
- err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp)
+ err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp)
return
}
@@ -263,169 +254,6 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken
return
}
-type ReqMakeJoin struct {
- RoomID id.RoomID
- UserID id.UserID
- Via string
- SupportedVersions []id.RoomVersion
-}
-
-type RespMakeJoin struct {
- RoomVersion id.RoomVersion `json:"room_version"`
- Event PDU `json:"event"`
-}
-
-type ReqSendJoin struct {
- RoomID id.RoomID
- EventID id.EventID
- OmitMembers bool
- Event PDU
- Via string
-}
-
-type ReqSendKnock struct {
- RoomID id.RoomID
- EventID id.EventID
- Event PDU
- Via string
-}
-
-type RespSendJoin struct {
- AuthChain []PDU `json:"auth_chain"`
- Event PDU `json:"event"`
- MembersOmitted bool `json:"members_omitted"`
- ServersInRoom []string `json:"servers_in_room"`
- State []PDU `json:"state"`
-}
-
-type RespSendKnock struct {
- KnockRoomState []PDU `json:"knock_room_state"`
-}
-
-type ReqSendInvite struct {
- RoomID id.RoomID `json:"-"`
- UserID id.UserID `json:"-"`
- Event PDU `json:"event"`
- InviteRoomState []PDU `json:"invite_room_state"`
- RoomVersion id.RoomVersion `json:"room_version"`
-}
-
-type RespSendInvite struct {
- Event PDU `json:"event"`
-}
-
-type ReqMakeLeave struct {
- RoomID id.RoomID
- UserID id.UserID
- Via string
-}
-
-type ReqSendLeave struct {
- RoomID id.RoomID
- EventID id.EventID
- Event PDU
- Via string
-}
-
-type (
- ReqMakeKnock = ReqMakeJoin
- RespMakeKnock = RespMakeJoin
- RespMakeLeave = RespMakeJoin
-)
-
-func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) {
- versions := make([]string, len(req.SupportedVersions))
- for i, v := range req.SupportedVersions {
- versions[i] = string(v)
- }
- _, _, err = c.MakeFullRequest(ctx, RequestParams{
- ServerName: req.Via,
- Method: http.MethodGet,
- Path: URLPath{"v1", "make_join", req.RoomID, req.UserID},
- Query: url.Values{"ver": versions},
- Authenticate: true,
- ResponseJSON: &resp,
- })
- return
-}
-
-func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) {
- versions := make([]string, len(req.SupportedVersions))
- for i, v := range req.SupportedVersions {
- versions[i] = string(v)
- }
- _, _, err = c.MakeFullRequest(ctx, RequestParams{
- ServerName: req.Via,
- Method: http.MethodGet,
- Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID},
- Query: url.Values{"ver": versions},
- Authenticate: true,
- ResponseJSON: &resp,
- })
- return
-}
-
-func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) {
- _, _, err = c.MakeFullRequest(ctx, RequestParams{
- ServerName: req.Via,
- Method: http.MethodPut,
- Path: URLPath{"v2", "send_join", req.RoomID, req.EventID},
- Query: url.Values{
- "omit_members": {strconv.FormatBool(req.OmitMembers)},
- },
- Authenticate: true,
- RequestJSON: req.Event,
- ResponseJSON: &resp,
- })
- return
-}
-
-func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) {
- _, _, err = c.MakeFullRequest(ctx, RequestParams{
- ServerName: req.Via,
- Method: http.MethodPut,
- Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID},
- Authenticate: true,
- RequestJSON: req.Event,
- ResponseJSON: &resp,
- })
- return
-}
-
-func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) {
- _, _, err = c.MakeFullRequest(ctx, RequestParams{
- ServerName: req.UserID.Homeserver(),
- Method: http.MethodPut,
- Path: URLPath{"v2", "invite", req.RoomID, req.UserID},
- Authenticate: true,
- RequestJSON: req,
- ResponseJSON: &resp,
- })
- return
-}
-
-func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) {
- _, _, err = c.MakeFullRequest(ctx, RequestParams{
- ServerName: req.Via,
- Method: http.MethodGet,
- Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID},
- Authenticate: true,
- ResponseJSON: &resp,
- })
- return
-}
-
-func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) {
- _, _, err = c.MakeFullRequest(ctx, RequestParams{
- ServerName: req.Via,
- Method: http.MethodPut,
- Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID},
- Authenticate: true,
- RequestJSON: req.Event,
- })
- return
-}
-
type URLPath []any
func (fup URLPath) FullPath() []any {
@@ -477,27 +305,15 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b
WrappedError: err,
}
}
- if !params.DontReadBody {
- defer resp.Body.Close()
- }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
var body []byte
- if resp.StatusCode >= 300 {
+ if resp.StatusCode >= 400 {
body, err = mautrix.ParseErrorResponse(req, resp)
return body, resp, err
} else if params.ResponseJSON != nil || !params.DontReadBody {
- if resp.ContentLength > c.ResponseSizeLimit {
- return body, resp, mautrix.HTTPError{
- Request: req,
- Response: resp,
-
- Message: "not reading response",
- WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024),
- }
- }
- body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1))
- if err == nil && len(body) > int(c.ResponseSizeLimit) {
- err = mautrix.ErrBodyReadReachedLimit
- }
+ body, err = io.ReadAll(resp.Body)
if err != nil {
return body, resp, mautrix.HTTPError{
Request: req,
@@ -588,7 +404,7 @@ func (r *signableRequest) Verify(key id.SigningKey, sig string) error {
if err != nil {
return fmt.Errorf("failed to marshal data: %w", err)
}
- return signutil.VerifyJSONRaw(key, sig, message)
+ return VerifyJSONRaw(key, sig, message)
}
func (r *signableRequest) Sign(key *SigningKey) (string, error) {
diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go
deleted file mode 100644
index c72933c2..00000000
--- a/federation/eventauth/eventauth.go
+++ /dev/null
@@ -1,851 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package eventauth
-
-import (
- "encoding/json"
- "encoding/json/jsontext"
- "errors"
- "fmt"
- "slices"
- "strconv"
- "strings"
-
- "github.com/tidwall/gjson"
- "go.mau.fi/util/exgjson"
- "go.mau.fi/util/exstrings"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/federation/pdu"
- "maunium.net/go/mautrix/federation/signutil"
- "maunium.net/go/mautrix/id"
-)
-
-type AuthFailError struct {
- Index string
- Message string
- Wrapped error
-}
-
-func (afe AuthFailError) Error() string {
- if afe.Message != "" {
- return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message)
- } else if afe.Wrapped != nil {
- return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error())
- }
- return fmt.Sprintf("fail %s", afe.Index)
-}
-
-func (afe AuthFailError) Unwrap() error {
- return afe.Wrapped
-}
-
-var mFederatePath = exgjson.Path("m.federate")
-
-var (
- ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"}
- ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"}
- ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"}
- ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion}
- ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"}
- ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"}
-
- ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"}
- ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"}
- ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"}
- ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"}
-
- ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"}
- ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"}
- ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"}
- ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"}
- ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"}
- ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"}
- ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"}
- ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"}
- ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"}
-
- ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"}
-
- ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"}
- ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"}
- ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"}
- ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"}
- ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"}
- ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"}
- ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"}
- ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"}
- ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"}
- ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"}
- ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"}
- ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"}
- ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"}
- ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"}
- ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"}
- ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"}
- ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"}
- ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"}
- ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"}
- ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"}
- ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"}
- ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"}
- ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"}
- ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"}
- ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"}
- ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"}
- ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"}
- ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"}
-
- ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"}
-
- ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"}
-
- ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"}
-
- ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"}
-
- ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"}
- ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"}
- ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"}
- ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"}
- ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"}
- ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"}
- ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"}
-)
-
-func isRejected(evt *pdu.PDU) bool {
- return evt.InternalMeta.Rejected
-}
-
-type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error)
-
-func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error {
- if evt.Type == event.StateCreate.Type {
- // 1. If type is m.room.create:
- return authorizeCreate(roomVersion, evt)
- }
- var createEvt *pdu.PDU
- if roomVersion.RoomIDIsCreateEventID() {
- // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event,
- // with the sigil ! instead of $, reject.
- if len(evt.RoomID) != 44 {
- return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID))
- } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil {
- return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err)
- } else if len(createEvts) != 1 {
- return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID)
- } else if isRejected(createEvts[0]) {
- return ErrRejectedCreateEvent
- } else {
- createEvt = createEvts[0]
- }
- }
- authEvents, err := getEvents(evt.AuthEvents)
- if err != nil {
- return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err)
- }
- expectedAuthEvents := evt.AuthEventSelection(roomVersion)
- deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents))
- // 3. Considering the event’s auth_events:
- for i, ae := range authEvents {
- authEvtID := evt.AuthEvents[i]
- if ae == nil {
- return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID)
- } else if ae.StateKey == nil {
- // This approximately falls under rule 3.2.
- return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID)
- }
- key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey}
- if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound {
- // 3.1. If there are duplicate entries for a given type and state_key pair, reject.
- return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID)
- } else if !expectedAuthEvents.Has(key) {
- // 3.2. If there are entries whose type and state_key don’t match those specified by
- // the auth events selection algorithm described in the server specification, reject.
- return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey)
- } else if isRejected(ae) {
- // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject.
- return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID)
- } else if ae.RoomID != evt.RoomID {
- // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject.
- return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID)
- } else {
- deduplicator[key] = authEvtID
- }
- if ae.Type == event.StateCreate.Type {
- if createEvt == nil {
- createEvt = ae
- } else {
- // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+
- panic(fmt.Errorf("impossible case: multiple create events found in auth events"))
- }
- }
- }
- if createEvt == nil {
- // This comes either from auth_events or room_id depending on the room version.
- // The checks above make sure it's from the right source.
- return ErrNoCreateEvent
- }
- if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() {
- // 4. If the content of the m.room.create event in the room state has the property m.federate set to false,
- // and the sender domain of the event does not match the sender domain of the create event, reject.
- return ErrFederationDisabled
- }
- if evt.Type == event.StateMember.Type {
- // 5. If type is m.room.member:
- return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey)
- }
- senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
- if senderMembership != event.MembershipJoin {
- // 6. If the sender’s current membership state is not join, reject.
- return ErrNotInRoom
- }
- powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
- if err != nil {
- return err
- }
- senderPL := powerLevels.GetUserLevel(evt.Sender)
- if evt.Type == event.StateThirdPartyInvite.Type {
- // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level.
- if senderPL >= powerLevels.Invite() {
- return nil
- }
- return ErrInsufficientPowerForThirdPartyInvite
- }
- typeClass := event.MessageEventType
- if evt.StateKey != nil {
- typeClass = event.StateEventType
- }
- evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass})
- if evtLevel > senderPL {
- // 8. If the event type’s required power level is greater than the sender’s power level, reject.
- return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL)
- }
-
- if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() {
- // 9. If the event has a state_key that starts with an @ and does not match the sender, reject.
- return ErrMismatchingPrivateStateKey
- }
-
- if evt.Type == event.StatePowerLevels.Type {
- // 10. If type is m.room.power_levels:
- return authorizePowerLevels(roomVersion, evt, createEvt, authEvents)
- }
-
- // 11. Otherwise, allow.
- return nil
-}
-
-var ErrUserIDNotAString = errors.New("not a string")
-var ErrUserIDNotValid = errors.New("not a valid user ID")
-
-func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error {
- if userID.Type != gjson.String {
- return ErrUserIDNotAString
- }
- // In a future room version, user IDs will have stricter validation
- _, _, err := id.UserID(userID.Str).Parse()
- if err != nil {
- return ErrUserIDNotValid
- }
- return nil
-}
-
-func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error {
- if len(evt.PrevEvents) > 0 {
- // 1.1. If it has any prev_events, reject.
- return ErrCreateHasPrevEvents
- }
- if roomVersion.RoomIDIsCreateEventID() {
- if evt.RoomID != "" {
- // 1.2. If the event has a room_id, reject.
- return ErrCreateHasRoomID
- }
- } else {
- _, _, server := id.ParseCommonIdentifier(evt.RoomID)
- if server == "" || server != evt.Sender.Homeserver() {
- // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject.
- return ErrRoomIDDoesntMatchSender
- }
- }
- if !roomVersion.IsKnown() {
- // 1.3. If content.room_version is present and is not a recognised version, reject.
- return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion)
- }
- if roomVersion.PrivilegedRoomCreators() {
- additionalCreators := gjson.GetBytes(evt.Content, "additional_creators")
- if additionalCreators.Exists() {
- if !additionalCreators.IsArray() {
- return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators)
- }
- for i, item := range additionalCreators.Array() {
- // 1.4. If additional_creators is present in content and is not an array of strings
- // where each string passes the same user ID validation applied to sender, reject.
- if err := isValidUserID(roomVersion, item); err != nil {
- return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err)
- }
- }
- }
- }
- if roomVersion.CreatorInContent() {
- // 1.4. (v10 and below) If content has no creator property, reject.
- if !gjson.GetBytes(evt.Content, "creator").Exists() {
- return ErrMissingCreator
- }
- }
- // 1.5. Otherwise, allow.
- return nil
-}
-
-func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error {
- membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str)
- if evt.StateKey == nil {
- // 5.1. If there is no state_key property, or no membership property in content, reject.
- return ErrMemberNotState
- }
- authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str)
- if authorizedVia != "" {
- homeserver := authorizedVia.Homeserver()
- err := evt.VerifySignature(roomVersion, homeserver, getKey)
- if err != nil {
- // 5.2. If content has a join_authorised_via_users_server key:
- // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject.
- return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err)
- }
- }
- targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave"))
- senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave"))
- switch membership {
- case event.MembershipJoin:
- createEvtID, err := createEvt.GetEventID(roomVersion)
- if err != nil {
- return fmt.Errorf("failed to get create event ID: %w", err)
- }
- creator := createEvt.Sender.String()
- if roomVersion.CreatorInContent() {
- creator = gjson.GetBytes(evt.Content, "creator").Str
- }
- if len(evt.PrevEvents) == 1 &&
- len(evt.AuthEvents) <= 1 &&
- evt.PrevEvents[0] == createEvtID &&
- *evt.StateKey == creator {
- // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow.
- return nil
- }
- // Spec wart: this would make more sense before the check above.
- // Now you can set anyone as the sender of the first join.
- if evt.Sender.String() != *evt.StateKey {
- // 5.3.2. If the sender does not match state_key, reject.
- return ErrCantJoinOtherUser
- }
-
- if senderMembership == event.MembershipBan {
- // 5.3.3. If the sender is banned, reject.
- return ErrCantJoinBanned
- }
-
- joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
- switch joinRule {
- case event.JoinRuleKnock:
- if !roomVersion.Knocks() {
- return ErrInvalidJoinRule
- }
- fallthrough
- case event.JoinRuleInvite:
- // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join.
- if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
- return nil
- }
- return ErrCantJoinWithoutInvite
- case event.JoinRuleKnockRestricted:
- if !roomVersion.KnockRestricted() {
- return ErrInvalidJoinRule
- }
- fallthrough
- case event.JoinRuleRestricted:
- if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() {
- return ErrInvalidJoinRule
- }
- if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite {
- // 5.3.5.1. If membership state is join or invite, allow.
- return nil
- }
- powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
- if err != nil {
- return err
- }
- if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() {
- // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject.
- return ErrAuthoriserCantInvite
- }
- authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave)))
- if authorizerMembership != event.MembershipJoin {
- return ErrAuthoriserNotInRoom
- }
- // 5.3.5.3. Otherwise, allow.
- return nil
- case event.JoinRulePublic:
- // 5.3.6. If the join_rule is public, allow.
- return nil
- default:
- // 5.3.7. Otherwise, reject.
- return ErrInvalidJoinRule
- }
- case event.MembershipInvite:
- tpiVal := gjson.GetBytes(evt.Content, "third_party_invite")
- if tpiVal.Exists() {
- if targetPrevMembership == event.MembershipBan {
- return ErrThirdPartyInviteBanned
- }
- signed := tpiVal.Get("signed")
- mxid := signed.Get("mxid").Str
- token := signed.Get("token").Str
- if mxid == "" || token == "" {
- // 5.4.1.2. If content.third_party_invite does not have a signed property, reject.
- // 5.4.1.3. If signed does not have mxid and token properties, reject.
- return ErrThirdPartyInviteMissingFields
- }
- if mxid != *evt.StateKey {
- // 5.4.1.4. If mxid does not match state_key, reject.
- return ErrThirdPartyInviteMXIDMismatch
- }
- tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token)
- if tpiEvt == nil {
- // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject.
- return ErrThirdPartyInviteNotFound
- }
- if tpiEvt.Sender != evt.Sender {
- // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject.
- return ErrThirdPartyInviteSenderMismatch
- }
- var keys []id.Ed25519
- const ed25519Base64Len = 43
- oldPubKey := gjson.GetBytes(evt.Content, "public_key.token")
- if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len {
- keys = append(keys, id.Ed25519(oldPubKey.Str))
- }
- gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool {
- if key.Type != gjson.Number {
- return false
- }
- if value.Type == gjson.String && len(value.Str) == ed25519Base64Len {
- keys = append(keys, id.Ed25519(value.Str))
- }
- return true
- })
- rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str))
- var validated bool
- for _, key := range keys {
- if signutil.VerifyJSONAny(key, rawSigned) == nil {
- validated = true
- }
- }
- if validated {
- // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow.
- return nil
- }
- // 4.4.1.8. Otherwise, reject.
- return ErrThirdPartyInviteNotSigned
- }
- if senderMembership != event.MembershipJoin {
- // 5.4.2. If the sender’s current membership state is not join, reject.
- return ErrInviterNotInRoom
- }
- // 5.4.3. If target user’s current membership state is join or ban, reject.
- if targetPrevMembership == event.MembershipJoin {
- return ErrInviteTargetAlreadyInRoom
- } else if targetPrevMembership == event.MembershipBan {
- return ErrInviteTargetBanned
- }
- powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
- if err != nil {
- return err
- }
- if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() {
- // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow.
- return nil
- }
- // 5.4.5. Otherwise, reject.
- return ErrInsufficientPermissionForInvite
- case event.MembershipLeave:
- if evt.Sender.String() == *evt.StateKey {
- // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock.
- if senderMembership == event.MembershipInvite ||
- senderMembership == event.MembershipJoin ||
- (senderMembership == event.MembershipKnock && roomVersion.Knocks()) {
- return nil
- }
- return ErrCantLeaveWithoutBeingInRoom
- }
- if senderMembership != event.MembershipJoin {
- // 5.5.2. If the sender’s current membership state is not join, reject.
- return ErrCantKickWithoutBeingInRoom
- }
- powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
- if err != nil {
- return err
- }
- senderLevel := powerLevels.GetUserLevel(evt.Sender)
- if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() {
- // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject.
- return ErrInsufficientPermissionForUnban
- }
- if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
- // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow.
- return nil
- }
- // TODO separate errors for < kick and < target user level?
- // 5.5.5. Otherwise, reject.
- return ErrInsufficientPermissionForKick
- case event.MembershipBan:
- if senderMembership != event.MembershipJoin {
- // 5.6.1. If the sender’s current membership state is not join, reject.
- return ErrCantBanWithoutBeingInRoom
- }
- powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt)
- if err != nil {
- return err
- }
- senderLevel := powerLevels.GetUserLevel(evt.Sender)
- if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel {
- // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow.
- return nil
- }
- // 5.6.3. Otherwise, reject.
- return ErrInsufficientPermissionForBan
- case event.MembershipKnock:
- joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite"))
- validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock
- validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted
- if !validKnockRule && !validKnockRestrictedRule {
- // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject.
- return ErrNotKnockableRoom
- }
- if evt.Sender.String() != *evt.StateKey {
- // 5.7.2. If the sender does not match state_key, reject.
- return ErrCantKnockOtherUser
- }
- if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin {
- // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow.
- return nil
- }
- // 5.7.4. Otherwise, reject.
- return ErrCantKnockWhileInRoom
- default:
- // 5.8. Otherwise, the membership is unknown. Reject.
- return ErrUnknownMembership
- }
-}
-
-func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error {
- if roomVersion.ValidatePowerLevelInts() {
- for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
- res := gjson.GetBytes(evt.Content, key)
- if !res.Exists() {
- continue
- }
- if parseIntWithVersion(roomVersion, res) == nil {
- // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject.
- return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
- }
- }
- for _, key := range []string{"events", "notifications"} {
- obj := gjson.GetBytes(evt.Content, key)
- if !obj.Exists() {
- continue
- }
- // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject.
- if !obj.IsObject() {
- return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key)
- }
- var err error
- // 10.2. [...] are not an object with values that are integers, reject.
- obj.ForEach(func(innerKey, value gjson.Result) bool {
- if parseIntWithVersion(roomVersion, value) == nil {
- err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str)
- return false
- }
- return true
- })
- if err != nil {
- return err
- }
- }
- }
- var creators []id.UserID
- if roomVersion.PrivilegedRoomCreators() {
- creators = append(creators, createEvt.Sender)
- gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool {
- creators = append(creators, id.UserID(value.Str))
- return true
- })
- }
- users := gjson.GetBytes(evt.Content, "users")
- if users.Exists() {
- if !users.IsObject() {
- // 10.3. If the users property in content is not an object [...], reject.
- return fmt.Errorf("%w users", ErrTopLevelPLNotInteger)
- }
- var err error
- users.ForEach(func(key, value gjson.Result) bool {
- if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil {
- // 10.3. [...] is not an object with keys that are valid user IDs [...], reject.
- err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr)
- return false
- }
- if parseIntWithVersion(roomVersion, value) == nil {
- // 10.3. [...] is not an object [...] with values that are integers, reject.
- err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str)
- return false
- }
- // creators is only filled if the room version has privileged room creators
- if slices.Contains(creators, id.UserID(key.Str)) {
- // 10.4. If the users property in content contains the sender of the m.room.create event or any of
- // the additional_creators array (if present) from the content of the m.room.create event, reject.
- err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str)
- return false
- }
- return true
- })
- if err != nil {
- return err
- }
- }
- oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "")
- if oldPL == nil {
- // 10.5. If there is no previous m.room.power_levels event in the room, allow.
- return nil
- }
- if slices.Contains(creators, evt.Sender) {
- // Skip remaining checks for creators
- return nil
- }
- senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String())))
- if senderPLPtr == nil {
- senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default"))
- if senderPLPtr == nil {
- senderPLPtr = ptr.Ptr(0)
- }
- }
- for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} {
- oldVal := gjson.GetBytes(oldPL.Content, key)
- newVal := gjson.GetBytes(evt.Content, key)
- if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil {
- return err
- }
- }
- if err := allowPowerChangeMap(
- roomVersion, *senderPLPtr, "events", "",
- gjson.GetBytes(oldPL.Content, "events"),
- gjson.GetBytes(evt.Content, "events"),
- ); err != nil {
- return err
- }
- if err := allowPowerChangeMap(
- roomVersion, *senderPLPtr, "notifications", "",
- gjson.GetBytes(oldPL.Content, "notifications"),
- gjson.GetBytes(evt.Content, "notifications"),
- ); err != nil {
- return err
- }
- if err := allowPowerChangeMap(
- roomVersion, *senderPLPtr, "users", evt.Sender.String(),
- gjson.GetBytes(oldPL.Content, "users"),
- gjson.GetBytes(evt.Content, "users"),
- ); err != nil {
- return err
- }
- return nil
-}
-
-func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) {
- old.ForEach(func(key, value gjson.Result) bool {
- newVal := new.Get(exgjson.Path(key.Str))
- err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal)
- if err == nil && ownID != "" && key.Str != ownID {
- parsedOldVal := parseIntWithVersion(roomVersion, value)
- parsedNewVal := parseIntWithVersion(roomVersion, newVal)
- if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal {
- err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal)
- }
- }
- return err == nil
- })
- if err != nil {
- return
- }
- new.ForEach(func(key, value gjson.Result) bool {
- err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value)
- return err == nil
- })
- return
-}
-
-func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error {
- oldVal := parseIntWithVersion(roomVersion, old)
- newVal := parseIntWithVersion(roomVersion, new)
- if oldVal == nil {
- if newVal == nil || *newVal <= maxVal {
- return nil
- }
- } else if newVal == nil {
- if *oldVal <= maxVal {
- return nil
- }
- } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) {
- return nil
- }
- return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal)
-}
-
-func stringifyForError(val gjson.Result) string {
- if !val.Exists() {
- return "null"
- }
- return val.Raw
-}
-
-func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU {
- for _, evt := range events {
- if evt.Type == evtType && *evt.StateKey == stateKey {
- return evt
- }
- }
- return nil
-}
-
-func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T {
- return reader(findEvent(events, evtType, stateKey))
-}
-
-func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string {
- return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string {
- if evt == nil {
- return defVal
- }
- res := gjson.GetBytes(evt.Content, fieldPath)
- if res.Type != gjson.String {
- return defVal
- }
- return res.Str
- })
-}
-
-func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) {
- var err error
- powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent {
- if evt == nil {
- return nil
- }
- content := evt.Content
- out := &event.PowerLevelsEventContent{}
- if !roomVersion.ValidatePowerLevelInts() {
- safeParsePowerLevels(content, out)
- } else {
- err = json.Unmarshal(content, out)
- }
- return out
- })
- if err != nil {
- // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+
- return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
- }
- if roomVersion.PrivilegedRoomCreators() {
- if powerLevels == nil {
- powerLevels = &event.PowerLevelsEventContent{}
- }
- powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion)
- if err != nil {
- return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
- }
- err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type)
- if err != nil {
- return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err)
- }
- } else if powerLevels == nil {
- powerLevels = &event.PowerLevelsEventContent{
- Users: map[id.UserID]int{
- createEvt.Sender: 100,
- },
- }
- }
- return powerLevels, nil
-}
-
-func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int {
- if roomVersion.ValidatePowerLevelInts() {
- if val.Type != gjson.Number {
- return nil
- }
- return ptr.Ptr(int(val.Int()))
- }
- return parsePythonInt(val)
-}
-
-func parsePythonInt(val gjson.Result) *int {
- switch val.Type {
- case gjson.True:
- return ptr.Ptr(1)
- case gjson.False:
- return ptr.Ptr(0)
- case gjson.Number:
- return ptr.Ptr(int(val.Int()))
- case gjson.String:
- // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand
- num, err := strconv.Atoi(strings.TrimSpace(val.Str))
- if err != nil {
- return nil
- }
- return &num
- default:
- // Python int() doesn't accept nulls, arrays or dicts
- return nil
- }
-}
-
-func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) {
- *into = event.PowerLevelsEventContent{
- Users: make(map[id.UserID]int),
- UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))),
- Events: make(map[string]int),
- EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))),
- Notifications: nil, // irrelevant for event auth
- StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")),
- InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")),
- KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")),
- BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")),
- RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")),
- }
- gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool {
- if key.Type != gjson.String {
- return false
- }
- val := parsePythonInt(value)
- if val != nil {
- into.Events[key.Str] = *val
- }
- return true
- })
- gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool {
- if key.Type != gjson.String {
- return false
- }
- val := parsePythonInt(value)
- if val == nil {
- return false
- }
- userID := id.UserID(key.Str)
- if _, _, err := userID.Parse(); err != nil {
- return false
- }
- into.Users[userID] = *val
- return true
- })
-}
diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go
deleted file mode 100644
index d316f3c8..00000000
--- a/federation/eventauth/eventauth_internal_test.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright (c) 2026 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package eventauth
-
-import (
- "strings"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/tidwall/gjson"
-)
-
-type pythonIntTest struct {
- Name string
- Input string
- Expected int64
-}
-
-var pythonIntTests = []pythonIntTest{
- {"True", `true`, 1},
- {"False", `false`, 0},
- {"SmallFloat", `3.1415`, 3},
- {"SmallFloatRoundDown", `10.999999999999999`, 10},
- {"SmallFloatRoundUp", `10.9999999999999999`, 11},
- {"BigFloatRoundDown", `1000000.9999999999`, 1000000},
- {"BigFloatRoundUp", `1000000.99999999999`, 1000001},
- {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992},
- {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994},
- {"Int64", `9223372036854775807`, 9223372036854775807},
- {"Int64String", `"9223372036854775807"`, 9223372036854775807},
- {"String", `"123"`, 123},
- {"InvalidFloatInString", `"123.456"`, 0},
- {"StringWithPlusSign", `"+123"`, 123},
- {"StringWithMinusSign", `"-123"`, -123},
- {"StringWithSpaces", `" 123 "`, 123},
- {"StringWithSpacesAndSign", `" -123 "`, -123},
- //{"StringWithUnderscores", `"123_456"`, 123456},
- //{"StringWithUnderscores", `"123_456"`, 123456},
- {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0},
- {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0},
- {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0},
- {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0},
- {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0},
- //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456},
-}
-
-func TestParsePythonInt(t *testing.T) {
- for _, test := range pythonIntTests {
- t.Run(test.Name, func(t *testing.T) {
- output := parsePythonInt(gjson.Parse(test.Input))
- if strings.HasPrefix(test.Name, "Invalid") {
- assert.Nil(t, output)
- } else {
- require.NotNil(t, output)
- assert.Equal(t, int(test.Expected), *output)
- }
- })
- }
-}
diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go
deleted file mode 100644
index e3c5cd76..00000000
--- a/federation/eventauth/eventauth_test.go
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package eventauth_test
-
-import (
- "embed"
- "encoding/json/jsontext"
- "encoding/json/v2"
- "errors"
- "io"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/tidwall/gjson"
- "go.mau.fi/util/exerrors"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix/federation/eventauth"
- "maunium.net/go/mautrix/federation/pdu"
- "maunium.net/go/mautrix/id"
-)
-
-//go:embed *.jsonl
-var data embed.FS
-
-type eventMap map[id.EventID]*pdu.PDU
-
-func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) {
- output := make([]*pdu.PDU, len(ids))
- for i, evtID := range ids {
- output[i] = em[evtID]
- }
- return output, nil
-}
-
-func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) {
- return "", time.Time{}, nil
-}
-
-func TestAuthorize(t *testing.T) {
- files := exerrors.Must(data.ReadDir("."))
- for _, file := range files {
- t.Run(file.Name(), func(t *testing.T) {
- decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name())))
- events := make(eventMap)
- var roomVersion *id.RoomVersion
- for i := 1; ; i++ {
- var evt *pdu.PDU
- err := json.UnmarshalDecode(decoder, &evt)
- if errors.Is(err, io.EOF) {
- break
- }
- require.NoError(t, err)
- if roomVersion == nil {
- require.Equal(t, evt.Type, "m.room.create")
- roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str))
- }
- expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str
- evtID, err := evt.GetEventID(*roomVersion)
- require.NoError(t, err)
- require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i)
-
- // TODO allow redacted events
- assert.True(t, evt.VerifyContentHash(), i)
-
- events[evtID] = evt
- err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey)
- if err != nil {
- evt.InternalMeta.Rejected = true
- }
- // TODO allow testing intentionally rejected events
- assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type)
- }
- })
- }
-
-}
diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl
deleted file mode 100644
index 2b751de3..00000000
--- a/federation/eventauth/testroom-v12-success.jsonl
+++ /dev/null
@@ -1,21 +0,0 @@
-{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}}
-{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}}
-{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
-{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}}
-{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}}
-{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}}
-{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
-{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}}
-{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}}
-{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}}
-{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}}
-{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}}
-{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}}
-{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}}
-{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}}
-{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}}
-{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}}
-{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}}
-{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}}
-{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}}
-{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}}
diff --git a/federation/keyserver.go b/federation/keyserver.go
index d32ba5cf..b0faf8fb 100644
--- a/federation/keyserver.go
+++ b/federation/keyserver.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -12,13 +12,9 @@ import (
"strconv"
"time"
- "github.com/rs/zerolog"
- "github.com/rs/zerolog/hlog"
- "go.mau.fi/util/exerrors"
+ "github.com/gorilla/mux"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/jsontime"
- "go.mau.fi/util/ptr"
- "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
@@ -55,25 +51,19 @@ type KeyServer struct {
}
// Register registers the key server endpoints to the given router.
-func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) {
- r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown)
- r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion)
- keyRouter := http.NewServeMux()
- keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey)
- keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys)
- keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys)
- errorBodies := exhttp.ErrorBodies{
- NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
- MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
- }
- r.Handle("/_matrix/key/", exhttp.ApplyMiddleware(
- keyRouter,
- exhttp.StripPrefix("/_matrix/key"),
- hlog.NewHandler(log),
- hlog.RequestIDHandler("request_id", "Request-Id"),
- requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
- exhttp.HandleErrors(errorBodies),
- ))
+func (ks *KeyServer) Register(r *mux.Router) {
+ r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
+ r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
+ keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
+ keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
+ keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
+ keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
+ keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ mautrix.MUnrecognized.WithStatus(http.StatusNotFound).WithMessage("Unrecognized endpoint").Write(w)
+ })
+ keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ mautrix.MUnrecognized.WithStatus(http.StatusMethodNotAllowed).WithMessage("Invalid method for endpoint").Write(w)
+ })
}
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
@@ -167,7 +157,7 @@ type GetQueryKeysResponse struct {
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
- serverName := r.PathValue("serverName")
+ serverName := mux.Vars(r)["serverName"]
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
if err != nil && minimumValidUntilTSString != "" {
diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go
deleted file mode 100644
index 16706fe5..00000000
--- a/federation/pdu/auth.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu
-
-import (
- "slices"
-
- "github.com/tidwall/gjson"
- "go.mau.fi/util/exgjson"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-type StateKey struct {
- Type string
- StateKey string
-}
-
-var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token")
-
-type AuthEventSelection []StateKey
-
-func (aes *AuthEventSelection) Add(evtType, stateKey string) {
- key := StateKey{Type: evtType, StateKey: stateKey}
- if !aes.Has(key) {
- *aes = append(*aes, key)
- }
-}
-
-func (aes *AuthEventSelection) Has(key StateKey) bool {
- return slices.Contains(*aes, key)
-}
-
-func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) {
- if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
- return AuthEventSelection{}
- }
- keys = make(AuthEventSelection, 0, 3)
- if !roomVersion.RoomIDIsCreateEventID() {
- keys.Add(event.StateCreate.Type, "")
- }
- keys.Add(event.StatePowerLevels.Type, "")
- keys.Add(event.StateMember.Type, pdu.Sender.String())
- if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
- keys.Add(event.StateMember.Type, *pdu.StateKey)
- membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
- if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
- keys.Add(event.StateJoinRules.Type, "")
- }
- if membership == event.MembershipInvite {
- thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
- if thirdPartyInviteToken != "" {
- keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
- }
- }
- if membership == event.MembershipJoin && roomVersion.RestrictedJoins() {
- authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str
- if authorizedVia != "" {
- keys.Add(event.StateMember.Type, authorizedVia)
- }
- }
- }
- return
-}
diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go
deleted file mode 100644
index 38ef83e9..00000000
--- a/federation/pdu/hash.go
+++ /dev/null
@@ -1,118 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu
-
-import (
- "crypto/hmac"
- "crypto/sha256"
- "encoding/base64"
- "fmt"
-
- "github.com/tidwall/gjson"
-
- "maunium.net/go/mautrix/id"
-)
-
-func (pdu *PDU) CalculateContentHash() ([32]byte, error) {
- if pdu == nil {
- return [32]byte{}, ErrPDUIsNil
- }
- pduClone := pdu.Clone()
- pduClone.Signatures = nil
- pduClone.Unsigned = nil
- pduClone.Hashes = nil
- rawJSON, err := marshalCanonical(pduClone)
- if err != nil {
- return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
- }
- return sha256.Sum256(rawJSON), nil
-}
-
-func (pdu *PDU) FillContentHash() error {
- if pdu == nil {
- return ErrPDUIsNil
- } else if pdu.Hashes != nil {
- return nil
- } else if hash, err := pdu.CalculateContentHash(); err != nil {
- return err
- } else {
- pdu.Hashes = &Hashes{SHA256: hash[:]}
- return nil
- }
-}
-
-func (pdu *PDU) VerifyContentHash() bool {
- if pdu == nil || pdu.Hashes == nil {
- return false
- }
- calculatedHash, err := pdu.CalculateContentHash()
- if err != nil {
- return false
- }
- return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
-}
-
-func (pdu *PDU) GetRoomID() (id.RoomID, error) {
- if pdu == nil {
- return "", ErrPDUIsNil
- } else if pdu.Type != "m.room.create" {
- return "", fmt.Errorf("room ID can only be calculated for m.room.create events")
- } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() {
- return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion)
- } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil {
- return "", fmt.Errorf("failed to calculate event ID: %w", err)
- } else {
- return id.RoomID(evtID), nil
- }
-}
-
-var UseInternalMetaForGetEventID = false
-
-func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
- if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" {
- return pdu.InternalMeta.EventID, nil
- }
- return pdu.calculateEventID(roomVersion, '$')
-}
-
-func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
- if pdu == nil {
- return [32]byte{}, ErrPDUIsNil
- }
- if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
- if err := pdu.FillContentHash(); err != nil {
- return [32]byte{}, err
- }
- }
- rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
- if err != nil {
- return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
- }
- return sha256.Sum256(rawJSON), nil
-}
-
-func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) {
- referenceHash, err := pdu.GetReferenceHash(roomVersion)
- if err != nil {
- return "", err
- }
- eventID := make([]byte, 44)
- eventID[0] = prefix
- switch roomVersion.EventIDFormat() {
- case id.EventIDFormatCustom:
- return "", fmt.Errorf("*pdu.PDU can only be used for room v3+")
- case id.EventIDFormatBase64:
- base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:])
- case id.EventIDFormatURLSafeBase64:
- base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:])
- default:
- return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat())
- }
- return id.EventID(eventID), nil
-}
diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go
deleted file mode 100644
index 17417e12..00000000
--- a/federation/pdu/hash_test.go
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu_test
-
-import (
- "encoding/base64"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "go.mau.fi/util/exerrors"
-)
-
-func TestPDU_CalculateContentHash(t *testing.T) {
- for _, test := range testPDUs {
- if test.redacted {
- continue
- }
- t.Run(test.name, func(t *testing.T) {
- parsed := parsePDU(test.pdu)
- contentHash := exerrors.Must(parsed.CalculateContentHash())
- assert.Equal(
- t,
- base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
- base64.RawStdEncoding.EncodeToString(contentHash[:]),
- )
- })
- }
-}
-
-func TestPDU_VerifyContentHash(t *testing.T) {
- for _, test := range testPDUs {
- if test.redacted {
- continue
- }
- t.Run(test.name, func(t *testing.T) {
- parsed := parsePDU(test.pdu)
- assert.True(t, parsed.VerifyContentHash())
- })
- }
-}
-
-func TestPDU_GetEventID(t *testing.T) {
- for _, test := range testPDUs {
- t.Run(test.name, func(t *testing.T) {
- gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion))
- assert.Equal(t, test.eventID, gotEventID)
- })
- }
-}
diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go
deleted file mode 100644
index 17db6995..00000000
--- a/federation/pdu/pdu.go
+++ /dev/null
@@ -1,156 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu
-
-import (
- "bytes"
- "crypto/ed25519"
- "encoding/json/jsontext"
- "encoding/json/v2"
- "errors"
- "fmt"
- "strings"
- "time"
-
- "github.com/tidwall/gjson"
- "go.mau.fi/util/jsonbytes"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix/crypto/canonicaljson"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU.
-//
-// The input time is the timestamp of the event. The function should attempt to fetch a key that is
-// valid at or after this time, but if that is not possible, the latest available key should be
-// returned without an error. The verify function will do its own validity checking based on the
-// returned valid until timestamp.
-type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error)
-
-type AnyPDU interface {
- GetRoomID() (id.RoomID, error)
- GetEventID(roomVersion id.RoomVersion) (id.EventID, error)
- GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error)
- CalculateContentHash() ([32]byte, error)
- FillContentHash() error
- VerifyContentHash() bool
- Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error
- VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error
- ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error)
- AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection)
-}
-
-var (
- _ AnyPDU = (*PDU)(nil)
- _ AnyPDU = (*RoomV1PDU)(nil)
-)
-
-type InternalMeta struct {
- EventID id.EventID `json:"event_id,omitempty"`
- Rejected bool `json:"rejected,omitempty"`
- Extra map[string]any `json:",unknown"`
-}
-
-type PDU struct {
- AuthEvents []id.EventID `json:"auth_events"`
- Content jsontext.Value `json:"content"`
- Depth int64 `json:"depth"`
- Hashes *Hashes `json:"hashes,omitzero"`
- OriginServerTS int64 `json:"origin_server_ts"`
- PrevEvents []id.EventID `json:"prev_events"`
- Redacts *id.EventID `json:"redacts,omitzero"`
- RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events
- Sender id.UserID `json:"sender"`
- Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
- StateKey *string `json:"state_key,omitzero"`
- Type string `json:"type"`
- Unsigned jsontext.Value `json:"unsigned,omitzero"`
- InternalMeta InternalMeta `json:"-"`
-
- Unknown jsontext.Value `json:",unknown"`
-
- // Deprecated legacy fields
- DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
- DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
- DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
-}
-
-var ErrPDUIsNil = errors.New("PDU is nil")
-
-type Hashes struct {
- SHA256 jsonbytes.UnpaddedBytes `json:"sha256"`
-
- Unknown jsontext.Value `json:",unknown"`
-}
-
-func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
- if pdu.Type == "m.room.create" && roomVersion == "" {
- roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str)
- }
- evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
- if pdu.StateKey != nil {
- evtType.Class = event.StateEventType
- }
- eventID, err := pdu.GetEventID(roomVersion)
- if err != nil {
- return nil, err
- }
- roomID := pdu.RoomID
- if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() {
- roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1))
- }
- evt := &event.Event{
- StateKey: pdu.StateKey,
- Sender: pdu.Sender,
- Type: evtType,
- Timestamp: pdu.OriginServerTS,
- ID: eventID,
- RoomID: roomID,
- Redacts: ptr.Val(pdu.Redacts),
- }
- err = json.Unmarshal(pdu.Content, &evt.Content)
- if err != nil {
- return nil, fmt.Errorf("failed to unmarshal content: %w", err)
- }
- return evt, nil
-}
-
-func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) {
- if signature == "" {
- return
- }
- if pdu.Signatures == nil {
- pdu.Signatures = make(map[string]map[id.KeyID]string)
- }
- if _, ok := pdu.Signatures[serverName]; !ok {
- pdu.Signatures[serverName] = make(map[id.KeyID]string)
- }
- pdu.Signatures[serverName][keyID] = signature
-}
-
-func marshalCanonical(data any) (jsontext.Value, error) {
- marshaledBytes, err := json.Marshal(data)
- if err != nil {
- return nil, err
- }
- marshaled := jsontext.Value(marshaledBytes)
- err = marshaled.Canonicalize()
- if err != nil {
- return nil, err
- }
- check := canonicaljson.CanonicalJSONAssumeValid(marshaled)
- if !bytes.Equal(marshaled, check) {
- fmt.Println(string(marshaled))
- fmt.Println(string(check))
- return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled))
- }
- return marshaled, nil
-}
diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go
deleted file mode 100644
index 59d7c3a6..00000000
--- a/federation/pdu/pdu_test.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu_test
-
-import (
- "encoding/json/v2"
- "time"
-
- "go.mau.fi/util/exerrors"
-
- "maunium.net/go/mautrix/federation/pdu"
- "maunium.net/go/mautrix/id"
-)
-
-type serverKey struct {
- key id.SigningKey
- validUntilTS time.Time
-}
-
-type serverDetails struct {
- serverName string
- keys map[id.KeyID]serverKey
-}
-
-func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
- if serverName != sd.serverName {
- return "", time.Time{}, nil
- }
- key, ok := sd.keys[keyID]
- if ok {
- return key.key, key.validUntilTS, nil
- }
- return "", time.Time{}, nil
-}
-
-var mauniumNet = serverDetails{
- serverName: "maunium.net",
- keys: map[id.KeyID]serverKey{
- "ed25519:a_xxeS": {
- key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg",
- validUntilTS: time.Now(),
- },
- },
-}
-var envsNet = serverDetails{
- serverName: "envs.net",
- keys: map[id.KeyID]serverKey{
- "ed25519:a_zIqy": {
- key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0",
- validUntilTS: time.UnixMilli(1722360538068),
- },
- "ed25519:wuJyKT": {
- key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0",
- validUntilTS: time.Now(),
- },
- },
-}
-var matrixOrg = serverDetails{
- serverName: "matrix.org",
- keys: map[id.KeyID]serverKey{
- "ed25519:auto": {
- key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw",
- validUntilTS: time.UnixMilli(1576767829750),
- },
- "ed25519:a_RXGa": {
- key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ",
- validUntilTS: time.Now(),
- },
- },
-}
-var continuwuityOrg = serverDetails{
- serverName: "continuwuity.org",
- keys: map[id.KeyID]serverKey{
- "ed25519:PwHlNsFu": {
- key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI",
- validUntilTS: time.Now(),
- },
- },
-}
-var novaAstraltechOrg = serverDetails{
- serverName: "nova.astraltech.org",
- keys: map[id.KeyID]serverKey{
- "ed25519:a_afpo": {
- key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o",
- validUntilTS: time.Now(),
- },
- },
-}
-
-type testPDU struct {
- name string
- pdu string
- eventID id.EventID
- roomVersion id.RoomVersion
- redacted bool
- serverDetails
-}
-
-var roomV4MessageTestPDU = testPDU{
- name: "m.room.message in v4 room",
- pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`,
- eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA",
- roomVersion: id.RoomV4,
- serverDetails: mauniumNet,
-}
-
-var roomV12MessageTestPDU = testPDU{
- name: "m.room.message in v12 room",
- pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`,
- eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc",
- roomVersion: id.RoomV12,
- serverDetails: mauniumNet,
-}
-
-var testPDUs = []testPDU{roomV4MessageTestPDU, {
- name: "m.room.message in v5 room",
- pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`,
- eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA",
- roomVersion: id.RoomV5,
- serverDetails: mauniumNet,
-}, {
- name: "m.room.message in v10 room",
- pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`,
- eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY",
- roomVersion: id.RoomV10,
- serverDetails: mauniumNet,
-}, {
- name: "m.room.message in v11 room",
- pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`,
- eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`,
- roomVersion: id.RoomV11,
- serverDetails: mauniumNet,
-}, roomV12MessageTestPDU, {
- name: "m.room.create in v4 room",
- pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`,
- eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY",
- roomVersion: id.RoomV4,
- serverDetails: matrixOrg,
-}, {
- name: "m.room.create in v10 room",
- pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`,
- eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ",
- roomVersion: id.RoomV10,
- serverDetails: envsNet,
-}, {
- name: "m.room.create in v12 room",
- pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`,
- eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
- roomVersion: id.RoomV12,
- serverDetails: mauniumNet,
-}, {
- name: "m.room.member in v4 room",
- pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`,
- eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo",
- roomVersion: id.RoomV4,
- serverDetails: mauniumNet,
-}, {
- name: "m.room.member in v10 room",
- pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`,
- eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs",
- roomVersion: id.RoomV10,
- serverDetails: mauniumNet,
-}, {
- name: "m.room.member of creator in v12 room",
- pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`,
- eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg",
- roomVersion: id.RoomV12,
- serverDetails: mauniumNet,
-}, {
- name: "custom message event in v4 room",
- pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`,
- eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0",
- roomVersion: id.RoomV4,
- serverDetails: mauniumNet,
-}, {
- name: "redacted m.room.member event in v11 room with 2 signatures",
- pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`,
- eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ",
- roomVersion: id.RoomV11,
- serverDetails: novaAstraltechOrg,
- redacted: true,
-}}
-
-func parsePDU(pdu string) (out *pdu.PDU) {
- exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
- return
-}
diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go
deleted file mode 100644
index d7ee0c15..00000000
--- a/federation/pdu/redact.go
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu
-
-import (
- "encoding/json/jsontext"
-
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
- "go.mau.fi/util/exgjson"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix/id"
-)
-
-func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value {
- filtered := jsontext.Value("{}")
- var err error
- for _, path := range allowedPaths {
- res := gjson.GetBytes(object, path)
- if res.Exists() {
- var raw jsontext.Value
- if res.Index > 0 {
- raw = object[res.Index : res.Index+len(res.Raw)]
- } else {
- raw = jsontext.Value(res.Raw)
- }
- filtered, err = sjson.SetRawBytes(filtered, path, raw)
- if err != nil {
- panic(err)
- }
- }
- }
- return filtered
-}
-
-func (pdu *PDU) Clone() *PDU {
- return ptr.Clone(pdu)
-}
-
-func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU {
- pdu.Signatures = nil
- return pdu.Redact(roomVersion)
-}
-
-var emptyObject = jsontext.Value("{}")
-
-func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value {
- switch eventType {
- case "m.room.member":
- allowedPaths := []string{"membership"}
- if roomVersion.RestrictedJoinsFix() {
- allowedPaths = append(allowedPaths, "join_authorised_via_users_server")
- }
- if roomVersion.UpdatedRedactionRules() {
- allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed"))
- }
- return filteredObject(content, allowedPaths...)
- case "m.room.create":
- if !roomVersion.UpdatedRedactionRules() {
- return filteredObject(content, "creator")
- }
- return content
- case "m.room.join_rules":
- if roomVersion.RestrictedJoins() {
- return filteredObject(content, "join_rule", "allow")
- }
- return filteredObject(content, "join_rule")
- case "m.room.power_levels":
- allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"}
- if roomVersion.UpdatedRedactionRules() {
- allowedKeys = append(allowedKeys, "invite")
- }
- return filteredObject(content, allowedKeys...)
- case "m.room.history_visibility":
- return filteredObject(content, "history_visibility")
- case "m.room.redaction":
- if roomVersion.RedactsInContent() {
- return filteredObject(content, "redacts")
- }
- return emptyObject
- case "m.room.aliases":
- if roomVersion.SpecialCasedAliasesAuth() {
- return filteredObject(content, "aliases")
- }
- return emptyObject
- default:
- return emptyObject
- }
-}
-
-func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU {
- pdu.Unknown = nil
- pdu.Unsigned = nil
- if roomVersion.UpdatedRedactionRules() {
- pdu.DeprecatedPrevState = nil
- pdu.DeprecatedOrigin = nil
- pdu.DeprecatedMembership = nil
- }
- if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() {
- pdu.Redacts = nil
- }
- pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
- return pdu
-}
diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go
deleted file mode 100644
index 04e7c5ef..00000000
--- a/federation/pdu/signature.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu
-
-import (
- "crypto/ed25519"
- "encoding/base64"
- "fmt"
- "time"
-
- "maunium.net/go/mautrix/federation/signutil"
- "maunium.net/go/mautrix/id"
-)
-
-func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
- err := pdu.FillContentHash()
- if err != nil {
- return err
- }
- rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
- if err != nil {
- return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
- }
- signature := ed25519.Sign(privateKey, rawJSON)
- pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature))
- return nil
-}
-
-func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
- rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
- if err != nil {
- return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
- }
- verified := false
- for keyID, sig := range pdu.Signatures[serverName] {
- originServerTS := time.UnixMilli(pdu.OriginServerTS)
- key, validUntil, err := getKey(serverName, keyID, originServerTS)
- if err != nil {
- return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
- } else if key == "" {
- return fmt.Errorf("key %s not found for %s", keyID, serverName)
- } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() {
- return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS)
- } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
- return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
- } else {
- verified = true
- }
- }
- if !verified {
- return fmt.Errorf("no verifiable signatures found for server %s", serverName)
- }
- return nil
-}
diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go
deleted file mode 100644
index 01df5076..00000000
--- a/federation/pdu/signature_test.go
+++ /dev/null
@@ -1,102 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu_test
-
-import (
- "crypto/ed25519"
- "encoding/base64"
- "encoding/json/jsontext"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "go.mau.fi/util/exerrors"
-
- "maunium.net/go/mautrix/federation/pdu"
- "maunium.net/go/mautrix/id"
-)
-
-func TestPDU_VerifySignature(t *testing.T) {
- for _, test := range testPDUs {
- t.Run(test.name, func(t *testing.T) {
- parsed := parsePDU(test.pdu)
- err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
- assert.NoError(t, err)
- })
- }
-}
-
-func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) {
- test := roomV12MessageTestPDU
- parsed := parsePDU(test.pdu)
- err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
- return
- })
- assert.Error(t, err)
-}
-
-func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) {
- test := roomV4MessageTestPDU
- parsed := parsePDU(test.pdu)
- err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
- key = test.keys[keyID].key
- validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
- return
- })
- assert.NoError(t, err)
-}
-
-func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) {
- test := roomV12MessageTestPDU
- parsed := parsePDU(test.pdu)
- err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
- key = test.keys[keyID].key
- validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
- return
- })
- assert.Error(t, err)
-}
-
-func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) {
- test := roomV12MessageTestPDU
- parsed := parsePDU(test.pdu)
- for _, sigs := range parsed.Signatures {
- for key := range sigs {
- sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC"
- }
- }
- err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey)
- assert.Error(t, err)
-}
-
-func TestPDU_Sign(t *testing.T) {
- pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil))
- evt := &pdu.PDU{
- AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"},
- Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`),
- Depth: 123,
- OriginServerTS: 1755384351627,
- PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"},
- RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
- Sender: "@tulir:example.com",
- Type: "m.room.message",
- }
- err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey)
- require.NoError(t, err)
- err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) {
- if serverName == "example.com" && keyID == "ed25519:rand" {
- key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey))
- validUntil = time.Now()
- }
- return
- })
- require.NoError(t, err)
-
-}
diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go
deleted file mode 100644
index 9557f8ab..00000000
--- a/federation/pdu/v1.go
+++ /dev/null
@@ -1,277 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu
-
-import (
- "crypto/ed25519"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/base64"
- "encoding/json/jsontext"
- "encoding/json/v2"
- "fmt"
- "time"
-
- "github.com/tidwall/gjson"
- "go.mau.fi/util/ptr"
-
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/federation/signutil"
- "maunium.net/go/mautrix/id"
-)
-
-type V1EventReference struct {
- ID id.EventID
- Hashes Hashes
-}
-
-var (
- _ json.UnmarshalerFrom = (*V1EventReference)(nil)
- _ json.MarshalerTo = (*V1EventReference)(nil)
-)
-
-func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error {
- return json.MarshalEncode(enc, []any{er.ID, er.Hashes})
-}
-
-func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
- var ref V1EventReference
- var data []jsontext.Value
- if err := json.UnmarshalDecode(dec, &data); err != nil {
- return err
- } else if len(data) != 2 {
- return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data))
- } else if err = json.Unmarshal(data[0], &ref.ID); err != nil {
- return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err)
- } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil {
- return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err)
- }
- *er = ref
- return nil
-}
-
-type RoomV1PDU struct {
- AuthEvents []V1EventReference `json:"auth_events"`
- Content jsontext.Value `json:"content"`
- Depth int64 `json:"depth"`
- EventID id.EventID `json:"event_id"`
- Hashes *Hashes `json:"hashes,omitzero"`
- OriginServerTS int64 `json:"origin_server_ts"`
- PrevEvents []V1EventReference `json:"prev_events"`
- Redacts *id.EventID `json:"redacts,omitzero"`
- RoomID id.RoomID `json:"room_id"`
- Sender id.UserID `json:"sender"`
- Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"`
- StateKey *string `json:"state_key,omitzero"`
- Type string `json:"type"`
- Unsigned jsontext.Value `json:"unsigned,omitzero"`
-
- Unknown jsontext.Value `json:",unknown"`
-
- // Deprecated legacy fields
- DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"`
- DeprecatedOrigin jsontext.Value `json:"origin,omitzero"`
- DeprecatedMembership jsontext.Value `json:"membership,omitzero"`
-}
-
-func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) {
- return pdu.RoomID, nil
-}
-
-func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) {
- if !pdu.SupportsRoomVersion(roomVersion) {
- return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion)
- }
- return pdu.EventID, nil
-}
-
-func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU {
- pdu.Signatures = nil
- return pdu.Redact(roomVersion)
-}
-
-func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU {
- pdu.Unknown = nil
- pdu.Unsigned = nil
- if pdu.Type != "m.room.redaction" {
- pdu.Redacts = nil
- }
- pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion)
- return pdu
-}
-
-func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) {
- if !pdu.SupportsRoomVersion(roomVersion) {
- return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion)
- }
- if pdu == nil {
- return [32]byte{}, ErrPDUIsNil
- }
- if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil {
- if err := pdu.FillContentHash(); err != nil {
- return [32]byte{}, err
- }
- }
- rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
- if err != nil {
- return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err)
- }
- return sha256.Sum256(rawJSON), nil
-}
-
-func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) {
- if pdu == nil {
- return [32]byte{}, ErrPDUIsNil
- }
- pduClone := pdu.Clone()
- pduClone.Signatures = nil
- pduClone.Unsigned = nil
- pduClone.Hashes = nil
- rawJSON, err := marshalCanonical(pduClone)
- if err != nil {
- return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err)
- }
- return sha256.Sum256(rawJSON), nil
-}
-
-func (pdu *RoomV1PDU) FillContentHash() error {
- if pdu == nil {
- return ErrPDUIsNil
- } else if pdu.Hashes != nil {
- return nil
- } else if hash, err := pdu.CalculateContentHash(); err != nil {
- return err
- } else {
- pdu.Hashes = &Hashes{SHA256: hash[:]}
- return nil
- }
-}
-
-func (pdu *RoomV1PDU) VerifyContentHash() bool {
- if pdu == nil || pdu.Hashes == nil {
- return false
- }
- calculatedHash, err := pdu.CalculateContentHash()
- if err != nil {
- return false
- }
- return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256)
-}
-
-func (pdu *RoomV1PDU) Clone() *RoomV1PDU {
- return ptr.Clone(pdu)
-}
-
-func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error {
- if !pdu.SupportsRoomVersion(roomVersion) {
- return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion)
- }
- err := pdu.FillContentHash()
- if err != nil {
- return err
- }
- rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
- if err != nil {
- return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
- }
- signature := ed25519.Sign(privateKey, rawJSON)
- if pdu.Signatures == nil {
- pdu.Signatures = make(map[string]map[id.KeyID]string)
- }
- if _, ok := pdu.Signatures[serverName]; !ok {
- pdu.Signatures[serverName] = make(map[id.KeyID]string)
- }
- pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature)
- return nil
-}
-
-func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error {
- if !pdu.SupportsRoomVersion(roomVersion) {
- return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion)
- }
- rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion))
- if err != nil {
- return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err)
- }
- verified := false
- for keyID, sig := range pdu.Signatures[serverName] {
- originServerTS := time.UnixMilli(pdu.OriginServerTS)
- key, _, err := getKey(serverName, keyID, originServerTS)
- if err != nil {
- return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err)
- } else if key == "" {
- return fmt.Errorf("key %s not found for %s", keyID, serverName)
- } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil {
- return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err)
- } else {
- verified = true
- }
- }
- if !verified {
- return fmt.Errorf("no verifiable signatures found for server %s", serverName)
- }
- return nil
-}
-
-func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool {
- switch roomVersion {
- case id.RoomV0, id.RoomV1, id.RoomV2:
- return true
- default:
- return false
- }
-}
-
-func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) {
- if !pdu.SupportsRoomVersion(roomVersion) {
- return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion)
- }
- evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType}
- if pdu.StateKey != nil {
- evtType.Class = event.StateEventType
- }
- evt := &event.Event{
- StateKey: pdu.StateKey,
- Sender: pdu.Sender,
- Type: evtType,
- Timestamp: pdu.OriginServerTS,
- ID: pdu.EventID,
- RoomID: pdu.RoomID,
- Redacts: ptr.Val(pdu.Redacts),
- }
- err := json.Unmarshal(pdu.Content, &evt.Content)
- if err != nil {
- return nil, fmt.Errorf("failed to unmarshal content: %w", err)
- }
- return evt, nil
-}
-
-func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) {
- if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil {
- return AuthEventSelection{}
- }
- keys = make(AuthEventSelection, 0, 3)
- keys.Add(event.StateCreate.Type, "")
- keys.Add(event.StatePowerLevels.Type, "")
- keys.Add(event.StateMember.Type, pdu.Sender.String())
- if pdu.Type == event.StateMember.Type && pdu.StateKey != nil {
- keys.Add(event.StateMember.Type, *pdu.StateKey)
- membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str)
- if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock {
- keys.Add(event.StateJoinRules.Type, "")
- }
- if membership == event.MembershipInvite {
- thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str
- if thirdPartyInviteToken != "" {
- keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken)
- }
- }
- }
- return
-}
diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go
deleted file mode 100644
index ecf2dbd2..00000000
--- a/federation/pdu/v1_test.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-//go:build goexperiment.jsonv2
-
-package pdu_test
-
-import (
- "encoding/base64"
- "encoding/json/v2"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "go.mau.fi/util/exerrors"
-
- "maunium.net/go/mautrix/federation/pdu"
- "maunium.net/go/mautrix/id"
-)
-
-var testV1PDUs = []testPDU{{
- name: "m.room.message in v1 room",
- pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`,
- eventID: "$16738169022163bokdi:maunium.net",
- roomVersion: id.RoomV1,
- serverDetails: mauniumNet,
-}, {
- name: "m.room.create in v1 room",
- pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`,
- eventID: "$143454825711DhCxH:matrix.org",
- roomVersion: id.RoomV1,
- serverDetails: matrixOrg,
-}, {
- name: "m.room.member in v1 room",
- pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`,
- eventID: "$15660585693272iEryv:maunium.net",
- roomVersion: id.RoomV1,
- serverDetails: mauniumNet,
-}}
-
-func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) {
- exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out))
- return
-}
-
-func TestRoomV1PDU_CalculateContentHash(t *testing.T) {
- for _, test := range testV1PDUs {
- t.Run(test.name, func(t *testing.T) {
- parsed := parseV1PDU(test.pdu)
- contentHash := exerrors.Must(parsed.CalculateContentHash())
- assert.Equal(
- t,
- base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256),
- base64.RawStdEncoding.EncodeToString(contentHash[:]),
- )
- })
- }
-}
-
-func TestRoomV1PDU_VerifyContentHash(t *testing.T) {
- for _, test := range testV1PDUs {
- t.Run(test.name, func(t *testing.T) {
- parsed := parseV1PDU(test.pdu)
- assert.True(t, parsed.VerifyContentHash())
- })
- }
-}
-
-func TestRoomV1PDU_VerifySignature(t *testing.T) {
- for _, test := range testV1PDUs {
- t.Run(test.name, func(t *testing.T) {
- parsed := parseV1PDU(test.pdu)
- err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) {
- key, ok := test.keys[keyID]
- if ok {
- return key.key, key.validUntilTS, nil
- }
- return "", time.Time{}, nil
- })
- assert.NoError(t, err)
- })
- }
-}
diff --git a/federation/resolution.go b/federation/resolution.go
index a3188266..69d4d3bf 100644
--- a/federation/resolution.go
+++ b/federation/resolution.go
@@ -20,8 +20,6 @@ import (
"time"
"github.com/rs/zerolog"
-
- "maunium.net/go/mautrix"
)
type ResolvedServerName struct {
@@ -80,10 +78,7 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
- wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
- if ok {
- hostname, port = wkHost, wkPort
- }
+ hostname, port, ok = ParseServerName(wellKnown.Server)
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
@@ -176,11 +171,9 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode)
- } else if resp.ContentLength > mautrix.WellKnownMaxSize {
- return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength)
}
var respData RespWellKnown
- err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData)
+ err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData)
if err != nil {
return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err)
} else if respData.Server == "" {
diff --git a/federation/serverauth.go b/federation/serverauth.go
index cd300341..f46c7991 100644
--- a/federation/serverauth.go
+++ b/federation/serverauth.go
@@ -231,7 +231,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res
}
err = (&signableRequest{
Method: r.Method,
- URI: r.URL.RequestURI(),
+ URI: r.URL.EscapedPath(),
Origin: parsed.Origin,
Destination: destination,
Content: reqBody,
diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go
index f99fc6cf..9fa15459 100644
--- a/federation/serverauth_test.go
+++ b/federation/serverauth_test.go
@@ -19,9 +19,9 @@ import (
func TestServerKeyResponse_VerifySelfSignature(t *testing.T) {
cli := federation.NewClient("", nil, nil)
ctx := context.Background()
- for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} {
+ for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} {
t.Run(name, func(t *testing.T) {
- resp, err := cli.ServerKeys(ctx, name)
+ resp, err := cli.ServerKeys(ctx, "matrix.org")
require.NoError(t, err)
assert.NoError(t, resp.VerifySelfSignature())
})
diff --git a/federation/signingkey.go b/federation/signingkey.go
index a4ad9679..0ae6a571 100644
--- a/federation/signingkey.go
+++ b/federation/signingkey.go
@@ -10,15 +10,17 @@ import (
"crypto/ed25519"
"encoding/base64"
"encoding/json"
+ "errors"
"fmt"
"strings"
"time"
+ "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
+ "go.mau.fi/util/exgjson"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/crypto/canonicaljson"
- "maunium.net/go/mautrix/federation/signutil"
"maunium.net/go/mautrix/id"
)
@@ -33,8 +35,8 @@ type SigningKey struct {
//
// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function.
func (sk *SigningKey) SynapseString() string {
- alg, keyID := sk.ID.Parse()
- return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
+ alg, id := sk.ID.Parse()
+ return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
}
// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey.
@@ -98,13 +100,56 @@ func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool {
func (skr *ServerKeyResponse) VerifySelfSignature() error {
for keyID, key := range skr.VerifyKeys {
- if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil {
+ if err := VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil {
return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err)
}
}
return nil
}
+func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error {
+ var err error
+ message, ok := data.(json.RawMessage)
+ if !ok {
+ message, err = json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ }
+ sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID)))
+ if sigVal.Type != gjson.String {
+ return ErrSignatureNotFound
+ }
+ message, err = sjson.DeleteBytes(message, "signatures")
+ if err != nil {
+ return fmt.Errorf("failed to delete signatures: %w", err)
+ }
+ message, err = sjson.DeleteBytes(message, "unsigned")
+ if err != nil {
+ return fmt.Errorf("failed to delete unsigned: %w", err)
+ }
+ return VerifyJSONRaw(key, sigVal.Str, message)
+}
+
+var ErrSignatureNotFound = errors.New("signature not found")
+var ErrInvalidSignature = errors.New("invalid signature")
+
+func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error {
+ sigBytes, err := base64.RawStdEncoding.DecodeString(sig)
+ if err != nil {
+ return fmt.Errorf("failed to decode signature: %w", err)
+ }
+ keyBytes, err := base64.RawStdEncoding.DecodeString(string(key))
+ if err != nil {
+ return fmt.Errorf("failed to decode key: %w", err)
+ }
+ message = canonicaljson.CanonicalJSONAssumeValid(message)
+ if !ed25519.Verify(keyBytes, message, sigBytes) {
+ return ErrInvalidSignature
+ }
+ return nil
+}
+
type marshalableSKR ServerKeyResponse
func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error {
diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go
deleted file mode 100644
index ea0e7886..00000000
--- a/federation/signutil/verify.go
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package signutil
-
-import (
- "crypto/ed25519"
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
-
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
- "go.mau.fi/util/exgjson"
-
- "maunium.net/go/mautrix/crypto/canonicaljson"
- "maunium.net/go/mautrix/id"
-)
-
-var ErrSignatureNotFound = errors.New("signature not found")
-var ErrInvalidSignature = errors.New("invalid signature")
-
-func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error {
- var err error
- message, ok := data.(json.RawMessage)
- if !ok {
- message, err = json.Marshal(data)
- if err != nil {
- return fmt.Errorf("failed to marshal data: %w", err)
- }
- }
- sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID)))
- if sigVal.Type != gjson.String {
- return ErrSignatureNotFound
- }
- message, err = sjson.DeleteBytes(message, "signatures")
- if err != nil {
- return fmt.Errorf("failed to delete signatures: %w", err)
- }
- message, err = sjson.DeleteBytes(message, "unsigned")
- if err != nil {
- return fmt.Errorf("failed to delete unsigned: %w", err)
- }
- return VerifyJSONRaw(key, sigVal.Str, message)
-}
-
-func VerifyJSONAny(key id.SigningKey, data any) error {
- var err error
- message, ok := data.(json.RawMessage)
- if !ok {
- message, err = json.Marshal(data)
- if err != nil {
- return fmt.Errorf("failed to marshal data: %w", err)
- }
- }
- sigs := gjson.GetBytes(message, "signatures")
- if !sigs.IsObject() {
- return ErrSignatureNotFound
- }
- message, err = sjson.DeleteBytes(message, "signatures")
- if err != nil {
- return fmt.Errorf("failed to delete signatures: %w", err)
- }
- message, err = sjson.DeleteBytes(message, "unsigned")
- if err != nil {
- return fmt.Errorf("failed to delete unsigned: %w", err)
- }
- var validated bool
- sigs.ForEach(func(_, value gjson.Result) bool {
- if !value.IsObject() {
- return true
- }
- value.ForEach(func(_, value gjson.Result) bool {
- if value.Type != gjson.String {
- return true
- }
- validated = VerifyJSONRaw(key, value.Str, message) == nil
- return !validated
- })
- return !validated
- })
- if !validated {
- return ErrInvalidSignature
- }
- return nil
-}
-
-func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error {
- sigBytes, err := base64.RawStdEncoding.DecodeString(sig)
- if err != nil {
- return fmt.Errorf("failed to decode signature: %w", err)
- }
- keyBytes, err := base64.RawStdEncoding.DecodeString(string(key))
- if err != nil {
- return fmt.Errorf("failed to decode key: %w", err)
- }
- message = canonicaljson.CanonicalJSONAssumeValid(message)
- if !ed25519.Verify(keyBytes, message, sigBytes) {
- return ErrInvalidSignature
- }
- return nil
-}
diff --git a/filter.go b/filter.go
index 54973dab..c6c8211b 100644
--- a/filter.go
+++ b/filter.go
@@ -57,7 +57,7 @@ type FilterPart struct {
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
- return errors.New("bad event_format value")
+ return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
}
return nil
}
diff --git a/format/htmlparser.go b/format/htmlparser.go
index e0507d93..f9d51e39 100644
--- a/format/htmlparser.go
+++ b/format/htmlparser.go
@@ -13,7 +13,6 @@ import (
"strconv"
"strings"
- "go.mau.fi/util/exstrings"
"golang.org/x/net/html"
"maunium.net/go/mautrix/event"
@@ -93,30 +92,6 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string
}
}
-func onlyBacktickCount(line string) (count int) {
- for i := 0; i < len(line); i++ {
- if line[i] != '`' {
- return -1
- }
- count++
- }
- return
-}
-
-func DefaultMonospaceBlockConverter(code, language string, ctx Context) string {
- if len(code) == 0 || code[len(code)-1] != '\n' {
- code += "\n"
- }
- fence := "```"
- for line := range strings.SplitSeq(code, "\n") {
- count := onlyBacktickCount(strings.TrimSpace(line))
- if count >= len(fence) {
- fence = strings.Repeat("`", count+1)
- }
- }
- return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence)
-}
-
// HTMLParser is a somewhat customizable Matrix HTML parser.
type HTMLParser struct {
PillConverter PillConverter
@@ -311,10 +286,7 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string {
}
if parser.LinkConverter != nil {
return parser.LinkConverter(str, href, ctx)
- } else if str == href ||
- str == strings.TrimPrefix(href, "mailto:") ||
- str == strings.TrimPrefix(href, "http://") ||
- str == strings.TrimPrefix(href, "https://") {
+ } else if str == href {
return str
}
return fmt.Sprintf("%s (%s)", str, href)
@@ -372,7 +344,10 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
if parser.MonospaceBlockConverter != nil {
return parser.MonospaceBlockConverter(preStr, language, ctx)
}
- return DefaultMonospaceBlockConverter(preStr, language, ctx)
+ if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' {
+ preStr += "\n"
+ }
+ return fmt.Sprintf("```%s\n%s```", language, preStr)
default:
return parser.nodeToTagAwareString(node.FirstChild, ctx)
}
@@ -393,7 +368,7 @@ func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) Tagge
switch node.Type {
case html.TextNode:
if !ctx.PreserveWhitespace {
- node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", ""))
+ node.Data = strings.Replace(node.Data, "\n", "", -1)
}
if parser.TextConverter != nil {
node.Data = parser.TextConverter(node.Data, ctx)
diff --git a/format/markdown.go b/format/markdown.go
index 77ced0dc..3d9979b4 100644
--- a/format/markdown.go
+++ b/format/markdown.go
@@ -57,18 +57,7 @@ type uriAble interface {
}
func MarkdownMention(id uriAble) string {
- return MarkdownMentionWithName(id.String(), id)
-}
-
-func MarkdownMentionWithName(name string, id uriAble) string {
- return MarkdownLink(name, id.URI().MatrixToURL())
-}
-
-func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string {
- if name == "" {
- name = id.String()
- }
- return MarkdownLink(name, id.URI(via...).MatrixToURL())
+ return MarkdownLink(id.String(), id.URI().MatrixToURL())
}
func MarkdownLink(name string, url string) string {
diff --git a/go.mod b/go.mod
index 49a1d4e4..59f29c0c 100644
--- a/go.mod
+++ b/go.mod
@@ -1,42 +1,43 @@
module maunium.net/go/mautrix
-go 1.25.0
+go 1.23.0
-toolchain go1.26.0
+toolchain go1.24.4
require (
- filippo.io/edwards25519 v1.2.0
+ filippo.io/edwards25519 v1.1.0
github.com/chzyer/readline v1.5.1
- github.com/coder/websocket v1.8.14
- github.com/lib/pq v1.11.2
- github.com/mattn/go-sqlite3 v1.14.34
+ github.com/gorilla/mux v1.8.0
+ github.com/gorilla/websocket v1.5.0
+ github.com/lib/pq v1.10.9
+ github.com/mattn/go-sqlite3 v1.14.28
github.com/rs/xid v1.6.0
github.com/rs/zerolog v1.34.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
- github.com/stretchr/testify v1.11.1
+ github.com/stretchr/testify v1.10.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
- github.com/yuin/goldmark v1.7.16
- go.mau.fi/util v0.9.6
- go.mau.fi/zeroconfig v0.2.0
- golang.org/x/crypto v0.48.0
- golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
- golang.org/x/net v0.50.0
- golang.org/x/sync v0.19.0
+ github.com/yuin/goldmark v1.7.12
+ go.mau.fi/util v0.8.8
+ go.mau.fi/zeroconfig v0.1.3
+ golang.org/x/crypto v0.40.0
+ golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc
+ golang.org/x/net v0.42.0
+ golang.org/x/sync v0.16.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
)
require (
- github.com/coreos/go-systemd/v22 v22.6.0 // indirect
+ github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
- github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
+ github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
- golang.org/x/sys v0.41.0 // indirect
- golang.org/x/text v0.34.0 // indirect
+ golang.org/x/sys v0.34.0 // indirect
+ golang.org/x/text v0.27.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)
diff --git a/go.sum b/go.sum
index 871a5156..9f48386e 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
-filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
+filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
@@ -8,16 +8,17 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
-github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
-github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
+github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
-github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo=
-github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
-github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
+github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
+github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
+github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@@ -25,10 +26,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
-github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14=
-github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
+github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
+github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb h1:3PrKuO92dUTMrQ9dx0YNejC6U/Si6jqKmyQ9vWjwqR4=
+github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -38,8 +39,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
-github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -50,28 +51,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE=
-github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
-go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts=
-go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI=
-go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU=
-go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w=
-golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
-golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
-golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
-golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
-golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
-golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
-golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
-golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY=
+github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
+go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c=
+go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c=
+go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM=
+go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
+golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
+golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
+golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
+golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
+golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
+golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
+golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
+golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
-golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
-golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
+golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
+golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
+golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
diff --git a/id/contenturi.go b/id/contenturi.go
index 67127b6c..e6a313f5 100644
--- a/id/contenturi.go
+++ b/id/contenturi.go
@@ -17,14 +17,8 @@ import (
)
var (
- ErrInvalidContentURI = errors.New("invalid Matrix content URI")
- ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
-)
-
-// Deprecated: use variables prefixed with Err
-var (
- InvalidContentURI = ErrInvalidContentURI
- InputNotJSONString = ErrInputNotJSONString
+ InvalidContentURI = errors.New("invalid Matrix content URI")
+ InputNotJSONString = errors.New("input doesn't look like a JSON string")
)
// ContentURIString is a string that's expected to be a Matrix content URI.
@@ -61,9 +55,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !strings.HasPrefix(uri, "mxc://") {
- err = ErrInvalidContentURI
+ err = InvalidContentURI
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = ErrInvalidContentURI
+ err = InvalidContentURI
} else {
parsed.Homeserver = uri[6 : 6+index]
parsed.FileID = uri[6+index+1:]
@@ -77,9 +71,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !bytes.HasPrefix(uri, mxcBytes) {
- err = ErrInvalidContentURI
+ err = InvalidContentURI
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = ErrInvalidContentURI
+ err = InvalidContentURI
} else {
parsed.Homeserver = string(uri[6 : 6+index])
parsed.FileID = string(uri[6+index+1:])
@@ -92,7 +86,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
*uri = ContentURI{}
return nil
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
- return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString)
+ return InputNotJSONString
}
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
if err != nil {
diff --git a/id/crypto.go b/id/crypto.go
index ee857f78..355a84a8 100644
--- a/id/crypto.go
+++ b/id/crypto.go
@@ -53,34 +53,6 @@ const (
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
)
-type KeySource string
-
-func (source KeySource) String() string {
- return string(source)
-}
-
-func (source KeySource) Int() int {
- switch source {
- case KeySourceDirect:
- return 100
- case KeySourceBackup:
- return 90
- case KeySourceImport:
- return 80
- case KeySourceForward:
- return 50
- default:
- return 0
- }
-}
-
-const (
- KeySourceDirect KeySource = "direct"
- KeySourceBackup KeySource = "backup"
- KeySourceImport KeySource = "import"
- KeySourceForward KeySource = "forward"
-)
-
// BackupVersion is an arbitrary string that identifies a server side key backup.
type KeyBackupVersion string
diff --git a/id/matrixuri.go b/id/matrixuri.go
index d5c78bc7..8f5ec849 100644
--- a/id/matrixuri.go
+++ b/id/matrixuri.go
@@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
func (uri *MatrixURI) getQuery() url.Values {
q := make(url.Values)
- if len(uri.Via) > 0 {
+ if uri.Via != nil && len(uri.Via) > 0 {
q["via"] = uri.Via
}
if len(uri.Action) > 0 {
diff --git a/id/opaque.go b/id/opaque.go
index c1ad4988..1d9f0dcf 100644
--- a/id/opaque.go
+++ b/id/opaque.go
@@ -32,9 +32,6 @@ type EventID string
// https://github.com/matrix-org/matrix-doc/pull/2716
type BatchID string
-// A DelayID is a string identifying a delayed event.
-type DelayID string
-
func (roomID RoomID) String() string {
return string(roomID)
}
diff --git a/id/roomversion.go b/id/roomversion.go
deleted file mode 100644
index 578c10bd..00000000
--- a/id/roomversion.go
+++ /dev/null
@@ -1,265 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package id
-
-import (
- "errors"
- "fmt"
- "slices"
-)
-
-type RoomVersion string
-
-const (
- RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1
- RoomV1 RoomVersion = "1"
- RoomV2 RoomVersion = "2"
- RoomV3 RoomVersion = "3"
- RoomV4 RoomVersion = "4"
- RoomV5 RoomVersion = "5"
- RoomV6 RoomVersion = "6"
- RoomV7 RoomVersion = "7"
- RoomV8 RoomVersion = "8"
- RoomV9 RoomVersion = "9"
- RoomV10 RoomVersion = "10"
- RoomV11 RoomVersion = "11"
- RoomV12 RoomVersion = "12"
-)
-
-func (rv RoomVersion) Equals(versions ...RoomVersion) bool {
- return slices.Contains(versions, rv)
-}
-
-func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool {
- return !rv.Equals(versions...)
-}
-
-var ErrUnknownRoomVersion = errors.New("unknown room version")
-
-func (rv RoomVersion) unknownVersionError() error {
- return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv)
-}
-
-func (rv RoomVersion) IsKnown() bool {
- switch rv {
- case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12:
- return true
- default:
- return false
- }
-}
-
-type StateResVersion int
-
-const (
- // StateResV1 is the original state resolution algorithm.
- StateResV1 StateResVersion = 0
- // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759
- StateResV2 StateResVersion = 1
- // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297
- StateResV2_1 StateResVersion = 2
-)
-
-// StateResVersion returns the version of the state resolution algorithm used by this room version.
-func (rv RoomVersion) StateResVersion() StateResVersion {
- switch rv {
- case RoomV0, RoomV1:
- return StateResV1
- case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11:
- return StateResV2
- case RoomV12:
- return StateResV2_1
- default:
- panic(rv.unknownVersionError())
- }
-}
-
-type EventIDFormat int
-
-const (
- // EventIDFormatCustom is the original format used by room v1 and v2.
- // Event IDs in this format are an arbitrary string followed by a colon and the server name.
- EventIDFormatCustom EventIDFormat = 0
- // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659.
- // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event.
- EventIDFormatBase64 EventIDFormat = 1
- // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002.
- // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event.
- EventIDFormatURLSafeBase64 EventIDFormat = 2
-)
-
-// EventIDFormat returns the format of event IDs used by this room version.
-func (rv RoomVersion) EventIDFormat() EventIDFormat {
- switch rv {
- case RoomV0, RoomV1, RoomV2:
- return EventIDFormatCustom
- case RoomV3:
- return EventIDFormatBase64
- default:
- return EventIDFormatURLSafeBase64
- }
-}
-
-/////////////////////
-// Room v5 changes //
-/////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/2077
-
-// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys
-// must be enforced on received events.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076
-func (rv RoomVersion) EnforceSigningKeyValidity() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4)
-}
-
-/////////////////////
-// Room v6 changes //
-/////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/2240
-
-// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased
-// to only always allow servers to modify the state event with their own server name as state key.
-// This also implies that the `aliases` field is protected from redactions.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432
-func (rv RoomVersion) SpecialCasedAliasesAuth() bool {
- return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
-}
-
-// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540
-func (rv RoomVersion) ForbidFloatsAndBigInts() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
-}
-
-// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth.
-// However, the field is not protected from redactions.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209
-func (rv RoomVersion) NotificationsPowerLevels() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5)
-}
-
-/////////////////////
-// Room v7 changes //
-/////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/2998
-
-// Knocks returns true if the `knock` join rule is supported.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403
-func (rv RoomVersion) Knocks() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6)
-}
-
-/////////////////////
-// Room v8 changes //
-/////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/3289
-
-// RestrictedJoins returns true if the `restricted` join rule is supported.
-// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083
-func (rv RoomVersion) RestrictedJoins() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7)
-}
-
-/////////////////////
-// Room v9 changes //
-/////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/3375
-
-// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375
-func (rv RoomVersion) RestrictedJoinsFix() bool {
- return rv.RestrictedJoins() && rv != RoomV8
-}
-
-//////////////////////
-// Room v10 changes //
-//////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/3604
-
-// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings).
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667
-func (rv RoomVersion) ValidatePowerLevelInts() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
-}
-
-// KnockRestricted returns true if the `knock_restricted` join rule is supported.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787
-func (rv RoomVersion) KnockRestricted() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9)
-}
-
-//////////////////////
-// Room v11 changes //
-//////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/3820
-
-// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175
-func (rv RoomVersion) CreatorInContent() bool {
- return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
-}
-
-// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level.
-// The redaction protection is also moved from the top level to the content field.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174
-// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection).
-func (rv RoomVersion) RedactsInContent() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
-}
-
-// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied.
-//
-// Specifically:
-//
-// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected.
-// * the entire content of `m.room.create` is protected.
-// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field.
-// * the `m.room.power_levels` event protects the `invite` field in content.
-// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176,
-// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and
-// https://github.com/matrix-org/matrix-spec-proposals/pull/3989
-func (rv RoomVersion) UpdatedRedactionRules() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10)
-}
-
-//////////////////////
-// Room v12 changes //
-//////////////////////
-// https://github.com/matrix-org/matrix-spec-proposals/pull/4304
-
-// Return value of StateResVersion was changed to StateResV2_1
-
-// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level.
-// This also implies that the `m.room.create` event has an `additional_creators` field,
-// and that the creators can't be present in the `m.room.power_levels` event.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289
-func (rv RoomVersion) PrivilegedRoomCreators() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
-}
-
-// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event.
-// This also implies that `m.room.create` events do not have a `room_id` field.
-//
-// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291
-func (rv RoomVersion) RoomIDIsCreateEventID() bool {
- return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11)
-}
diff --git a/id/trust.go b/id/trust.go
index 6255093e..04f6e36b 100644
--- a/id/trust.go
+++ b/id/trust.go
@@ -16,7 +16,6 @@ type TrustState int
const (
TrustStateBlacklisted TrustState = -100
- TrustStateDeviceKeyMismatch TrustState = -5
TrustStateUnset TrustState = 0
TrustStateUnknownDevice TrustState = 10
TrustStateForwarded TrustState = 20
@@ -24,7 +23,7 @@ const (
TrustStateCrossSignedTOFU TrustState = 100
TrustStateCrossSignedVerified TrustState = 200
TrustStateVerified TrustState = 300
- TrustStateInvalid TrustState = -2147483647
+ TrustStateInvalid TrustState = (1 << 31) - 1
)
func (ts *TrustState) UnmarshalText(data []byte) error {
@@ -45,8 +44,6 @@ func ParseTrustState(val string) TrustState {
switch strings.ToLower(val) {
case "blacklisted":
return TrustStateBlacklisted
- case "device-key-mismatch":
- return TrustStateDeviceKeyMismatch
case "unverified":
return TrustStateUnset
case "cross-signed-untrusted":
@@ -70,8 +67,6 @@ func (ts TrustState) String() string {
switch ts {
case TrustStateBlacklisted:
return "blacklisted"
- case TrustStateDeviceKeyMismatch:
- return "device-key-mismatch"
case TrustStateUnset:
return "unverified"
case TrustStateCrossSignedUntrusted:
diff --git a/id/userid.go b/id/userid.go
index 726a0d58..6d9f4080 100644
--- a/id/userid.go
+++ b/id/userid.go
@@ -104,24 +104,16 @@ func ValidateUserLocalpart(localpart string) error {
return nil
}
-// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts.
-// This should be used with care: there are real users still using historical localparts.
-func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.ParseAndValidateRelaxed()
+// ParseAndValidate parses the user ID into the localpart and server name like Parse,
+// and also validates that the localpart is allowed according to the user identifiers spec.
+func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) {
+ localpart, homeserver, err = userID.Parse()
if err == nil {
err = ValidateUserLocalpart(localpart)
}
- return
-}
-
-// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse,
-// and also validates that the user ID is not too long and that the server name is valid.
-func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) {
- if len(userID) > UserIDMaxLength {
+ if err == nil && len(userID) > UserIDMaxLength {
err = ErrUserIDTooLong
- return
}
- localpart, homeserver, err = userID.Parse()
if err == nil && !ValidateServerName(homeserver) {
err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart)
}
@@ -129,7 +121,7 @@ func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, er
}
func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) {
- localpart, homeserver, err = userID.ParseAndValidateStrict()
+ localpart, homeserver, err = userID.ParseAndValidate()
if err == nil {
localpart, err = DecodeUserLocalpart(localpart)
}
@@ -219,15 +211,15 @@ func DecodeUserLocalpart(str string) (string, error) {
for i := 0; i < len(strBytes); i++ {
b := strBytes[i]
if !isValidByte(b) {
- return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
+ return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
}
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
if i+1 >= len(strBytes) {
- return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
+ return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
}
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
- return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
+ return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
}
if strBytes[i+1] == '_' {
outputBuffer.WriteByte('_')
@@ -237,7 +229,7 @@ func DecodeUserLocalpart(str string) (string, error) {
i++ // skip next byte since we just handled it
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
if i+2 >= len(strBytes) {
- return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
+ return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
}
dst := make([]byte, 1)
_, err := hex.Decode(dst, strBytes[i+1:i+3])
diff --git a/id/userid_test.go b/id/userid_test.go
index 57a88066..359bc687 100644
--- a/id/userid_test.go
+++ b/id/userid_test.go
@@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) {
assert.True(t, errors.Is(err, id.ErrInvalidUserID))
}
-func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) {
+func TestUserID_ParseAndValidate_Invalid(t *testing.T) {
const inputUserID = "@s p a c e:maunium.net"
- _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
+ _, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart))
}
-func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) {
+func TestUserID_ParseAndValidate_Empty(t *testing.T) {
const inputUserID = "@:ponies.im"
- _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
+ _, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrEmptyLocalpart))
}
-func TestUserID_ParseAndValidateStrict_Long(t *testing.T) {
+func TestUserID_ParseAndValidate_Long(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
+ _, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.Error(t, err)
assert.True(t, errors.Is(err, id.ErrUserIDTooLong))
}
-func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) {
+func TestUserID_ParseAndValidate_NotLong(t *testing.T) {
const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com"
- _, _, err := id.UserID(inputUserID).ParseAndValidateStrict()
+ _, _, err := id.UserID(inputUserID).ParseAndValidate()
assert.NoError(t, err)
}
@@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) {
const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8"
const inputServerName = "example.com"
userID := id.NewEncodedUserID(inputLocalpart, inputServerName)
- parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict()
+ parsedLocalpart, parsedServerName, err := userID.ParseAndValidate()
assert.NoError(t, err)
assert.Equal(t, encodedLocalpart, parsedLocalpart)
assert.Equal(t, inputServerName, parsedServerName)
diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go
index 4d2bc7cf..4be799d3 100644
--- a/mediaproxy/mediaproxy.go
+++ b/mediaproxy/mediaproxy.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -21,12 +21,8 @@ import (
"strings"
"time"
+ "github.com/gorilla/mux"
"github.com/rs/zerolog"
- "github.com/rs/zerolog/hlog"
- "go.mau.fi/util/exerrors"
- "go.mau.fi/util/exhttp"
- "go.mau.fi/util/ptr"
- "go.mau.fi/util/requestlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/federation"
@@ -95,13 +91,9 @@ func (d *GetMediaResponseCallback) GetContentType() string {
return d.ContentType
}
-type FileMeta struct {
- ContentType string
- ReplacementFile string
-}
-
type GetMediaResponseFile struct {
- Callback func(w *os.File) (*FileMeta, error)
+ Callback func(w *os.File) error
+ ContentType string
}
type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
@@ -116,8 +108,8 @@ type MediaProxy struct {
serverName string
serverKey *federation.SigningKey
- FederationRouter *http.ServeMux
- ClientMediaRouter *http.ServeMux
+ FederationRouter *mux.Router
+ ClientMediaRouter *mux.Router
}
func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) {
@@ -125,7 +117,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
if err != nil {
return nil, err
}
- mp := &MediaProxy{
+ return &MediaProxy{
serverName: serverName,
serverKey: parsed,
GetMedia: getMedia,
@@ -140,21 +132,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"),
},
},
- }
- mp.FederationRouter = http.NewServeMux()
- mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
- mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation)
- mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
- mp.ClientMediaRouter = http.NewServeMux()
- mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
- mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia)
- mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia)
- mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported)
- mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported)
- mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported)
- mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported)
- mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported)
- return mp, nil
+ }, nil
}
type BasicConfig struct {
@@ -184,8 +162,8 @@ type ServerConfig struct {
}
func (mp *MediaProxy) Listen(cfg ServerConfig) error {
- router := http.NewServeMux()
- mp.RegisterRoutes(router, zerolog.Nop())
+ router := mux.NewRouter()
+ mp.RegisterRoutes(router)
return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
}
@@ -210,29 +188,39 @@ func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache feder
})
}
-func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) {
- errorBodies := exhttp.ErrorBodies{
- NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()),
- MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()),
+func (mp *MediaProxy) RegisterRoutes(router *mux.Router) {
+ if mp.FederationRouter == nil {
+ mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter()
}
- router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware(
- mp.FederationRouter,
- exhttp.StripPrefix("/_matrix/federation"),
- hlog.NewHandler(log),
- hlog.RequestIDHandler("request_id", "Request-Id"),
- requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
- exhttp.HandleErrors(errorBodies),
- ))
- router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware(
- mp.ClientMediaRouter,
- exhttp.StripPrefix("/_matrix/client/v1/media"),
- hlog.NewHandler(log),
- hlog.RequestIDHandler("request_id", "Request-Id"),
- exhttp.CORSMiddleware,
- requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}),
- exhttp.HandleErrors(errorBodies),
- ))
- mp.KeyServer.Register(router, log)
+ if mp.ClientMediaRouter == nil {
+ mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter()
+ }
+
+ mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet)
+ mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet)
+ mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
+ mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet)
+ mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet)
+ mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut)
+ mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost)
+ mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost)
+ mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet)
+ mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet)
+ mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
+ mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
+ mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint)
+ mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod)
+ corsMiddleware := func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
+ w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';")
+ next.ServeHTTP(w, r)
+ })
+ }
+ mp.ClientMediaRouter.Use(corsMiddleware)
+ mp.KeyServer.Register(router)
}
var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax")
@@ -246,7 +234,7 @@ func queryToMap(vals url.Values) map[string]string {
}
func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse {
- mediaID := r.PathValue("mediaID")
+ mediaID := mux.Vars(r)["mediaID"]
if !id.IsValidMediaID(mediaID) {
mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w)
return nil
@@ -392,7 +380,8 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin
func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := zerolog.Ctx(ctx)
- if r.PathValue("serverName") != mp.serverName {
+ vars := mux.Vars(r)
+ if vars["serverName"] != mp.serverName {
mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w)
return
}
@@ -415,7 +404,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTemporaryRedirect)
} else if fileResp, ok := resp.(*GetMediaResponseFile); ok {
responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error {
- mp.addHeaders(w, mimeType, r.PathValue("fileName"))
+ mp.addHeaders(w, mimeType, vars["fileName"])
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
w.WriteHeader(http.StatusOK)
_, err := wt.WriteTo(w)
@@ -436,7 +425,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) {
if dataResp, ok := writerResp.(*GetMediaResponseData); ok {
defer dataResp.Reader.Close()
}
- mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName"))
+ mp.addHeaders(w, writerResp.GetContentType(), vars["fileName"])
if writerResp.GetContentLength() != 0 {
w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10))
}
@@ -458,35 +447,23 @@ func doTempFileDownload(
if err != nil {
return false, fmt.Errorf("failed to create temp file: %w", err)
}
- origTempFile := tempFile
defer func() {
- _ = origTempFile.Close()
- _ = os.Remove(origTempFile.Name())
+ _ = tempFile.Close()
+ _ = os.Remove(tempFile.Name())
}()
- meta, err := data.Callback(tempFile)
+ err = data.Callback(tempFile)
if err != nil {
return false, err
}
- if meta.ReplacementFile != "" {
- tempFile, err = os.Open(meta.ReplacementFile)
- if err != nil {
- return false, fmt.Errorf("failed to open replacement file: %w", err)
- }
- defer func() {
- _ = tempFile.Close()
- _ = os.Remove(origTempFile.Name())
- }()
- } else {
- _, err = tempFile.Seek(0, io.SeekStart)
- if err != nil {
- return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
- }
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
}
fileInfo, err := tempFile.Stat()
if err != nil {
return false, fmt.Errorf("failed to stat temp file: %w", err)
}
- mimeType := meta.ContentType
+ mimeType := data.ContentType
if mimeType == "" {
buf := make([]byte, 512)
n, err := tempFile.Read(buf)
@@ -514,6 +491,11 @@ var (
ErrPreviewURLNotSupported = mautrix.MUnrecognized.
WithMessage("This is a media proxy and does not support URL previews.").
WithStatus(http.StatusNotImplemented)
+ ErrUnknownEndpoint = mautrix.MUnrecognized.
+ WithMessage("Unrecognized endpoint")
+ ErrUnsupportedMethod = mautrix.MUnrecognized.
+ WithMessage("Invalid method for endpoint").
+ WithStatus(http.StatusMethodNotAllowed)
)
func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) {
@@ -523,3 +505,11 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request)
func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) {
ErrPreviewURLNotSupported.Write(w)
}
+
+func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) {
+ ErrUnknownEndpoint.Write(w)
+}
+
+func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) {
+ ErrUnsupportedMethod.Write(w)
+}
diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go
deleted file mode 100644
index 507c24a5..00000000
--- a/mockserver/mockserver.go
+++ /dev/null
@@ -1,307 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package mockserver
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "maps"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log
- "github.com/stretchr/testify/require"
- "go.mau.fi/util/dbutil"
- "go.mau.fi/util/exerrors"
- "go.mau.fi/util/exhttp"
- "go.mau.fi/util/random"
-
- "maunium.net/go/mautrix"
- "maunium.net/go/mautrix/crypto"
- "maunium.net/go/mautrix/crypto/cryptohelper"
- "maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/id"
-)
-
-func mustDecode(r *http.Request, data any) {
- exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data))
-}
-
-type userAndDeviceID struct {
- UserID id.UserID
- DeviceID id.DeviceID
-}
-
-type MockServer struct {
- Router *http.ServeMux
- Server *httptest.Server
-
- AccessTokenToUserID map[string]userAndDeviceID
- DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event
- AccountData map[id.UserID]map[event.Type]json.RawMessage
- DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys
- OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey
- MasterKeys map[id.UserID]mautrix.CrossSigningKeys
- SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys
- UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys
-
- PopOTKs bool
- MemoryStore bool
-}
-
-func Create(t testing.TB) *MockServer {
- t.Helper()
-
- server := MockServer{
- AccessTokenToUserID: map[string]userAndDeviceID{},
- DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{},
- AccountData: map[id.UserID]map[event.Type]json.RawMessage{},
- DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
- OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
- MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- PopOTKs: true,
- MemoryStore: true,
- }
-
- router := http.NewServeMux()
- router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin)
- router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery)
- router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim)
- router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice)
- router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData)
- router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload)
- router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp)
- router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload)
- server.Router = router
- server.Server = httptest.NewServer(router)
- t.Cleanup(server.Server.Close)
- return &server
-}
-
-func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID {
- authHeader := r.Header.Get("Authorization")
- authHeader = strings.TrimPrefix(authHeader, "Bearer ")
- userID, ok := ms.AccessTokenToUserID[authHeader]
- if !ok {
- panic("no user ID found for access token " + authHeader)
- }
- return userID
-}
-
-func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) {
- exhttp.WriteEmptyJSONResponse(w, http.StatusOK)
-}
-
-func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) {
- var loginReq mautrix.ReqLogin
- mustDecode(r, &loginReq)
-
- deviceID := loginReq.DeviceID
- if deviceID == "" {
- deviceID = id.DeviceID(random.String(10))
- }
-
- accessToken := random.String(30)
- userID := id.UserID(loginReq.Identifier.User)
- ms.AccessTokenToUserID[accessToken] = userAndDeviceID{
- UserID: userID,
- DeviceID: deviceID,
- }
-
- exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{
- AccessToken: accessToken,
- DeviceID: deviceID,
- UserID: userID,
- })
-}
-
-func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) {
- var req mautrix.ReqSendToDevice
- mustDecode(r, &req)
- evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType}
-
- for user, devices := range req.Messages {
- for device, content := range devices {
- if _, ok := ms.DeviceInbox[user]; !ok {
- ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{}
- }
- content.ParseRaw(evtType)
- ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{
- Sender: ms.getUserID(r).UserID,
- Type: evtType,
- Content: *content,
- })
- }
- }
- ms.emptyResp(w, r)
-}
-
-func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) {
- userID := id.UserID(r.PathValue("userID"))
- eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType}
-
- jsonData, _ := io.ReadAll(r.Body)
- if _, ok := ms.AccountData[userID]; !ok {
- ms.AccountData[userID] = map[event.Type]json.RawMessage{}
- }
- ms.AccountData[userID][eventType] = json.RawMessage(jsonData)
- ms.emptyResp(w, r)
-}
-
-func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) {
- var req mautrix.ReqQueryKeys
- mustDecode(r, &req)
- resp := mautrix.RespQueryKeys{
- MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{},
- DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{},
- }
- for user := range req.DeviceKeys {
- resp.MasterKeys[user] = ms.MasterKeys[user]
- resp.UserSigningKeys[user] = ms.UserSigningKeys[user]
- resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user]
- resp.DeviceKeys[user] = ms.DeviceKeys[user]
- }
- exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
-}
-
-func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) {
- var req mautrix.ReqClaimKeys
- mustDecode(r, &req)
- resp := mautrix.RespClaimKeys{
- OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{},
- }
- for user, devices := range req.OneTimeKeys {
- resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
- for device := range devices {
- keys := ms.OneTimeKeys[user][device]
- for keyID, key := range keys {
- if ms.PopOTKs {
- delete(keys, keyID)
- }
- resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{
- keyID: key,
- }
- break
- }
- }
- }
- exhttp.WriteJSONResponse(w, http.StatusOK, &resp)
-}
-
-func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
- var req mautrix.ReqUploadKeys
- mustDecode(r, &req)
-
- uid := ms.getUserID(r)
- userID := uid.UserID
- if _, ok := ms.DeviceKeys[userID]; !ok {
- ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{}
- }
- if _, ok := ms.OneTimeKeys[userID]; !ok {
- ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}
- }
-
- if req.DeviceKeys != nil {
- ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys
- }
- otks, ok := ms.OneTimeKeys[userID][uid.DeviceID]
- if !ok {
- otks = map[id.KeyID]mautrix.OneTimeKey{}
- ms.OneTimeKeys[userID][uid.DeviceID] = otks
- }
- if req.OneTimeKeys != nil {
- maps.Copy(otks, req.OneTimeKeys)
- }
-
- exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{
- OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)},
- })
-}
-
-func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
- var req mautrix.UploadCrossSigningKeysReq[any]
- mustDecode(r, &req)
-
- userID := ms.getUserID(r).UserID
- ms.MasterKeys[userID] = req.Master
- ms.SelfSigningKeys[userID] = req.SelfSigning
- ms.UserSigningKeys[userID] = req.UserSigning
-
- ms.emptyResp(w, r)
-}
-
-func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) {
- t.Helper()
- if ctx == nil {
- ctx = context.TODO()
- }
- client, err := mautrix.NewClient(ms.Server.URL, "", "")
- require.NoError(t, err)
- client.Client = ms.Server.Client()
-
- _, err = client.Login(ctx, &mautrix.ReqLogin{
- Type: mautrix.AuthTypePassword,
- Identifier: mautrix.UserIdentifier{
- Type: mautrix.IdentifierTypeUser,
- User: userID.String(),
- },
- DeviceID: deviceID,
- Password: "password",
- StoreCredentials: true,
- })
- require.NoError(t, err)
-
- var store any
- if ms.MemoryStore {
- store = crypto.NewMemoryStore(nil)
- client.StateStore = mautrix.NewMemoryStateStore()
- } else {
- store, err = dbutil.NewFromConfig("", dbutil.Config{
- PoolConfig: dbutil.PoolConfig{
- Type: "sqlite3-fk-wal",
- URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)),
- MaxOpenConns: 5,
- MaxIdleConns: 1,
- },
- }, nil)
- require.NoError(t, err)
- }
- cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store)
- require.NoError(t, err)
- client.Crypto = cryptoHelper
-
- err = cryptoHelper.Init(ctx)
- require.NoError(t, err)
-
- machineLog := globallog.Logger.With().
- Stringer("my_user_id", userID).
- Stringer("my_device_id", deviceID).
- Logger()
- cryptoHelper.Machine().Log = &machineLog
-
- err = cryptoHelper.Machine().ShareKeys(ctx, 50)
- require.NoError(t, err)
-
- return client, cryptoHelper.Machine().CryptoStore
-}
-
-func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) {
- t.Helper()
-
- for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] {
- client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt)
- ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:]
- }
-}
diff --git a/pushrules/action.go b/pushrules/action.go
index b5a884b2..9838e88b 100644
--- a/pushrules/action.go
+++ b/pushrules/action.go
@@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
if ok {
action.Action = ActionSetTweak
action.Tweak = PushActionTweak(tweak)
- action.Value = val["value"]
+ action.Value, _ = val["value"]
}
}
return nil
diff --git a/pushrules/action_test.go b/pushrules/action_test.go
index 3c0aa168..a8f68415 100644
--- a/pushrules/action_test.go
+++ b/pushrules/action_test.go
@@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`))
- assert.NoError(t, err)
+ assert.Nil(t, err)
err = pa.UnmarshalJSON([]byte(`9001`))
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`"foo"`))
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, pushrules.PushActionType("foo"), pa.Action)
assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak)
@@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) {
}
err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`))
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, pushrules.ActionSetTweak, pa.Action)
assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak)
@@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data)
}
@@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) {
Value: "bar",
}
data, err := pa.MarshalJSON()
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.Equal(t, []byte(`"something else"`), data)
}
diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go
index 37af3e34..0d3eaf7a 100644
--- a/pushrules/condition_test.go
+++ b/pushrules/condition_test.go
@@ -102,6 +102,14 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
}
}
+func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
+ return &pushrules.PushCondition{
+ Kind: pushrules.KindEventPropertyContains,
+ Key: key,
+ Value: value,
+ }
+}
+
func TestPushCondition_Match_InvalidKind(t *testing.T) {
condition := &pushrules.PushCondition{
Kind: pushrules.PushCondKind("invalid"),
diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go
index a5a0f5e7..a531ca28 100644
--- a/pushrules/pushrules_test.go
+++ b/pushrules/pushrules_test.go
@@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) {
},
}
pushRuleset, err := pushrules.EventToPushRules(evt)
- assert.NoError(t, err)
+ assert.Nil(t, err)
assert.NotNil(t, pushRuleset)
assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{})
diff --git a/requests.go b/requests.go
index cc8b7266..09e4b3cd 100644
--- a/requests.go
+++ b/requests.go
@@ -2,7 +2,6 @@ package mautrix
import (
"encoding/json"
- "fmt"
"strconv"
"time"
@@ -40,40 +39,20 @@ const (
type Direction rune
-func (d Direction) MarshalJSON() ([]byte, error) {
- return json.Marshal(string(d))
-}
-
-func (d *Direction) UnmarshalJSON(data []byte) error {
- var str string
- if err := json.Unmarshal(data, &str); err != nil {
- return err
- }
- switch str {
- case "f":
- *d = DirectionForward
- case "b":
- *d = DirectionBackward
- default:
- return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str)
- }
- return nil
-}
-
const (
DirectionForward Direction = 'f'
DirectionBackward Direction = 'b'
)
// ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
-type ReqRegister[UIAType any] struct {
+type ReqRegister struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
InhibitLogin bool `json:"inhibit_login,omitempty"`
RefreshToken bool `json:"refresh_token,omitempty"`
- Auth UIAType `json:"auth,omitempty"`
+ Auth interface{} `json:"auth,omitempty"`
// Type for registration, only used for appservice user registrations
// https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions
@@ -141,12 +120,11 @@ type ReqCreateRoom struct {
InitialState []*event.Event `json:"initial_state,omitempty"`
Preset string `json:"preset,omitempty"`
IsDirect bool `json:"is_direct,omitempty"`
- RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ RoomVersion event.RoomVersion `json:"room_version,omitempty"`
PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"`
MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"`
- MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"`
BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"`
BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"`
BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"`
@@ -183,11 +161,6 @@ type ReqKnockRoom struct {
Reason string `json:"reason,omitempty"`
}
-type ReqSearchUserDirectory struct {
- SearchTerm string `json:"search_term"`
- Limit int `json:"limit,omitempty"`
-}
-
type ReqMutualRooms struct {
From string `json:"-"`
}
@@ -320,11 +293,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 {
return ""
}
-type UploadCrossSigningKeysReq[UIAType any] struct {
+type UploadCrossSigningKeysReq struct {
Master CrossSigningKeys `json:"master_key"`
SelfSigning CrossSigningKeys `json:"self_signing_key"`
UserSigning CrossSigningKeys `json:"user_signing_key"`
- Auth UIAType `json:"auth,omitempty"`
+ Auth interface{} `json:"auth,omitempty"`
}
type KeyMap map[id.DeviceKeyID]string
@@ -367,23 +340,18 @@ type ReqSendToDevice struct {
}
type ReqSendEvent struct {
- Timestamp int64
- TransactionID string
- UnstableDelay time.Duration
- UnstableStickyDuration time.Duration
- DontEncrypt bool
- MeowEventID id.EventID
-}
+ Timestamp int64
+ TransactionID string
+ UnstableDelay time.Duration
-type ReqDelayedEvents struct {
- DelayID id.DelayID `json:"-"`
- Status event.DelayStatus `json:"-"`
- NextBatch string `json:"-"`
+ DontEncrypt bool
+
+ MeowEventID id.EventID
}
type ReqUpdateDelayedEvent struct {
- DelayID id.DelayID `json:"-"`
- Action event.DelayAction `json:"action"`
+ DelayID string `json:"-"`
+ Action string `json:"action"` // TODO use enum
}
// ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid
@@ -392,14 +360,14 @@ type ReqDeviceInfo struct {
}
// ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid
-type ReqDeleteDevice[UIAType any] struct {
- Auth UIAType `json:"auth,omitempty"`
+type ReqDeleteDevice struct {
+ Auth interface{} `json:"auth,omitempty"`
}
// ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices
-type ReqDeleteDevices[UIAType any] struct {
+type ReqDeleteDevices struct {
Devices []id.DeviceID `json:"devices"`
- Auth UIAType `json:"auth,omitempty"`
+ Auth interface{} `json:"auth,omitempty"`
}
type ReqPutPushRule struct {
@@ -411,6 +379,18 @@ type ReqPutPushRule struct {
Pattern string `json:"pattern"`
}
+// Deprecated: MSC2716 was abandoned
+type ReqBatchSend struct {
+ PrevEventID id.EventID `json:"-"`
+ BatchID id.BatchID `json:"-"`
+
+ BeeperNewMessages bool `json:"-"`
+ BeeperMarkReadBy id.UserID `json:"-"`
+
+ StateEventsAtStart []*event.Event `json:"state_events_at_start"`
+ Events []*event.Event `json:"events"`
+}
+
type ReqBeeperBatchSend struct {
// ForwardIfNoMessages should be set to true if the batch should be forward
// backfilled if there are no messages currently in the room.
@@ -606,13 +586,3 @@ func (rgr *ReqGetRelations) Query() map[string]string {
}
return query
}
-
-// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
-type ReqSuspend struct {
- Suspended bool `json:"suspended"`
-}
-
-// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
-type ReqLocked struct {
- Locked bool `json:"locked"`
-}
diff --git a/responses.go b/responses.go
index 4fbe1fbc..20d02af5 100644
--- a/responses.go
+++ b/responses.go
@@ -6,14 +6,12 @@ import (
"fmt"
"maps"
"reflect"
- "slices"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
- "go.mau.fi/util/ptr"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -106,22 +104,11 @@ type RespContext struct {
type RespSendEvent struct {
EventID id.EventID `json:"event_id"`
- UnstableDelayID id.DelayID `json:"delay_id,omitempty"`
+ UnstableDelayID string `json:"delay_id,omitempty"`
}
type RespUpdateDelayedEvent struct{}
-type RespDelayedEvents struct {
- Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"`
- Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"`
- NextBatch string `json:"next_batch,omitempty"`
-
- // Deprecated: Synapse implementation still returns this
- DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"`
- // Deprecated: Synapse implementation still returns this
- FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"`
-}
-
type RespRedactUserEvents struct {
IsMoreEvents bool `json:"is_more_events"`
RedactedEvents struct {
@@ -223,52 +210,25 @@ func (r *RespUserProfile) MarshalJSON() ([]byte, error) {
} else {
delete(marshalMap, "avatar_url")
}
- return json.Marshal(marshalMap)
-}
-
-type RespSearchUserDirectory struct {
- Limited bool `json:"limited"`
- Results []*UserDirectoryEntry `json:"results"`
-}
-
-type UserDirectoryEntry struct {
- RespUserProfile
- UserID id.UserID `json:"user_id"`
-}
-
-func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error {
- err := r.RespUserProfile.UnmarshalJSON(data)
- if err != nil {
- return err
- }
- userIDStr, _ := r.Extra["user_id"].(string)
- r.UserID = id.UserID(userIDStr)
- delete(r.Extra, "user_id")
- return nil
-}
-
-func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) {
- if r.Extra == nil {
- r.Extra = make(map[string]any)
- }
- r.Extra["user_id"] = r.UserID.String()
- return r.RespUserProfile.MarshalJSON()
+ return json.Marshal(r.Extra)
}
type RespMutualRooms struct {
Joined []id.RoomID `json:"joined"`
NextBatch string `json:"next_batch,omitempty"`
- Count int `json:"count,omitempty"`
}
type RespRoomSummary struct {
PublicRoomInfo
- Membership event.Membership `json:"membership,omitempty"`
+ Membership event.Membership `json:"membership,omitempty"`
+ RoomVersion event.RoomVersion `json:"room_version,omitempty"`
+ Encryption id.Algorithm `json:"encryption,omitempty"`
+ AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
- UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
- UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"`
- UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"`
+ UnstableRoomVersion event.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
+ UnstableRoomVersionOld event.RoomVersion `json:"im.nheko.summary.version,omitempty"`
+ UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"`
}
// RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable
@@ -342,24 +302,6 @@ type LazyLoadSummary struct {
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
}
-func (lls *LazyLoadSummary) MemberCount() int {
- if lls == nil {
- return 0
- }
- return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount)
-}
-
-func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool {
- if lls == other {
- return true
- } else if lls == nil || other == nil {
- return false
- }
- return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) &&
- ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) &&
- slices.Equal(lls.Heroes, other.Heroes)
-}
-
type SyncEventsList struct {
Events []*event.Event `json:"events,omitempty"`
}
@@ -455,7 +397,7 @@ type BeeperInboxPreviewEvent struct {
type SyncJoinedRoom struct {
Summary LazyLoadSummary `json:"summary"`
State SyncEventsList `json:"state"`
- StateAfter *SyncEventsList `json:"state_after,omitempty"`
+ StateAfter *SyncEventsList `json:"org.matrix.msc4222.state_after,omitempty"`
Timeline SyncTimeline `json:"timeline"`
Ephemeral SyncEventsList `json:"ephemeral"`
AccountData SyncEventsList `json:"account_data"`
@@ -546,19 +488,30 @@ type RespDeviceInfo struct {
LastSeenTS int64 `json:"last_seen_ts"`
}
+// Deprecated: MSC2716 was abandoned
+type RespBatchSend struct {
+ StateEventIDs []id.EventID `json:"state_event_ids"`
+ EventIDs []id.EventID `json:"event_ids"`
+
+ InsertionEventID id.EventID `json:"insertion_event_id"`
+ BatchEventID id.EventID `json:"batch_event_id"`
+ BaseInsertionEventID id.EventID `json:"base_insertion_event_id"`
+
+ NextBatchID id.BatchID `json:"next_batch_id"`
+}
+
type RespBeeperBatchSend struct {
EventIDs []id.EventID `json:"event_ids"`
}
// RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities
type RespCapabilities struct {
- RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
- ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
- SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
- SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
- ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
- GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"`
- UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"`
+ RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`
+ ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"`
+ SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"`
+ SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"`
+ ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"`
+ GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"`
Custom map[string]interface{} `json:"-"`
}
@@ -667,11 +620,6 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool {
return available
}
-type CapUnstableAccountModeration struct {
- Suspend bool `json:"suspend"`
- Lock bool `json:"lock"`
-}
-
type RespPublicRooms struct {
Chunk []*PublicRoomInfo `json:"chunk"`
NextBatch string `json:"next_batch,omitempty"`
@@ -690,10 +638,6 @@ type PublicRoomInfo struct {
RoomType event.RoomType `json:"room_type"`
Topic string `json:"topic,omitempty"`
WorldReadable bool `json:"world_readable"`
-
- RoomVersion id.RoomVersion `json:"room_version,omitempty"`
- Encryption id.Algorithm `json:"encryption,omitempty"`
- AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
}
// RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
@@ -704,7 +648,12 @@ type RespHierarchy struct {
type ChildRoomsChunk struct {
PublicRoomInfo
- ChildrenState []*event.Event `json:"children_state"`
+ ChildrenState []StrippedStateWithTime `json:"children_state"`
+}
+
+type StrippedStateWithTime struct {
+ event.StrippedState
+ Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
}
type RespAppservicePing struct {
@@ -767,33 +716,3 @@ type RespGetRelations struct {
PrevBatch string `json:"prev_batch,omitempty"`
RecursionDepth int `json:"recursion_depth,omitempty"`
}
-
-// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
-type RespSuspended struct {
- Suspended bool `json:"suspended"`
-}
-
-// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323
-type RespLocked struct {
- Locked bool `json:"locked"`
-}
-
-type ConnectionInfo struct {
- IP string `json:"ip,omitempty"`
- LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"`
- UserAgent string `json:"user_agent,omitempty"`
-}
-
-type SessionInfo struct {
- Connections []ConnectionInfo `json:"connections,omitempty"`
-}
-
-type DeviceInfo struct {
- Sessions []SessionInfo `json:"sessions,omitempty"`
-}
-
-// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
-type RespWhoIs struct {
- UserID id.UserID `json:"user_id,omitempty"`
- Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"`
-}
diff --git a/responses_test.go b/responses_test.go
index 73d82635..b23d85ad 100644
--- a/responses_test.go
+++ b/responses_test.go
@@ -8,6 +8,7 @@ package mautrix_test
import (
"encoding/json"
+ "fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -85,6 +86,7 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) {
var caps mautrix.RespCapabilities
err := json.Unmarshal([]byte(sampleData), &caps)
require.NoError(t, err)
+ fmt.Println(caps)
require.NotNil(t, caps.RoomVersions)
assert.Equal(t, "9", caps.RoomVersions.Default)
diff --git a/room.go b/room.go
index 4292bff5..c3ddb7e6 100644
--- a/room.go
+++ b/room.go
@@ -5,6 +5,8 @@ import (
"maunium.net/go/mautrix/id"
)
+type RoomStateMap = map[event.Type]map[string]*event.Event
+
// Room represents a single Matrix room.
type Room struct {
ID id.RoomID
@@ -23,8 +25,8 @@ func (room Room) UpdateState(evt *event.Event) {
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
- stateEventMap := room.State[eventType]
- evt := stateEventMap[stateKey]
+ stateEventMap, _ := room.State[eventType]
+ evt, _ := stateEventMap[stateKey]
return evt
}
diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go
index 11957dfa..4a220a2b 100644
--- a/sqlstatestore/statestore.go
+++ b/sqlstatestore/statestore.go
@@ -62,9 +62,6 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID)
}
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
- if userID == "" {
- return fmt.Errorf("user ID is empty")
- }
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
return err
}
@@ -185,11 +182,6 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID,
}
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
- if roomID == "" {
- return fmt.Errorf("room ID is empty")
- } else if userID == "" {
- return fmt.Errorf("user ID is empty")
- }
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
@@ -222,11 +214,6 @@ func (u *userProfileRow) GetMassInsertValues() [5]any {
var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
- if roomID == "" {
- return fmt.Errorf("room ID is empty")
- } else if userID == "" {
- return fmt.Errorf("user ID is empty")
- }
var nameSkeleton []byte
if !store.DisableNameDisambiguation && len(member.Displayname) > 0 {
nameSkeletonArr := confusable.SkeletonHash(member.Displayname)
@@ -248,9 +235,6 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room
const userProfileMassInsertBatchSize = 500
func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
- if roomID == "" {
- return fmt.Errorf("room ID is empty")
- }
return store.DoTxn(ctx, nil, func(ctx context.Context) error {
err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
if err != nil {
@@ -321,9 +305,6 @@ func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.Roo
}
func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
- if roomID == "" {
- return fmt.Errorf("room ID is empty")
- }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true)
ON CONFLICT (room_id) DO UPDATE SET members_fetched=true
@@ -353,9 +334,6 @@ func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID)
}
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
- if roomID == "" {
- return fmt.Errorf("room ID is empty")
- }
contentBytes, err := json.Marshal(content)
if err != nil {
return fmt.Errorf("failed to marshal content JSON: %w", err)
@@ -370,7 +348,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
var data []byte
err := store.
- QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID).
+ QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
@@ -393,9 +371,6 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (
}
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
- if roomID == "" {
- return fmt.Errorf("room ID is empty")
- }
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
@@ -404,92 +379,89 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID
}
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
- levels = &event.PowerLevelsEventContent{}
err = store.
- QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID).
- Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent})
+ QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
+ Scan(&dbutil.JSON{Data: &levels})
if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- } else if err != nil {
- return nil, err
- }
- if levels.CreateEvent != nil {
- err = levels.CreateEvent.Content.ParseRaw(event.StateCreate)
+ err = nil
}
return
}
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
+ if store.Dialect == dbutil.Postgres {
+ var powerLevel int
+ err := store.
+ QueryRow(ctx, `
+ SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
+ FROM mx_room_state WHERE room_id=$1
+ `, roomID, userID).
+ Scan(&powerLevel)
+ return powerLevel, err
+ } else {
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
+ }
+ return levels.GetUserLevel(userID), nil
}
- return levels.GetUserLevel(userID), nil
}
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return 0, err
+ if store.Dialect == dbutil.Postgres {
+ defaultType := "events_default"
+ defaultValue := 0
+ if eventType.IsState() {
+ defaultType = "state_default"
+ defaultValue = 50
+ }
+ var powerLevel int
+ err := store.
+ QueryRow(ctx, `
+ SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
+ FROM mx_room_state WHERE room_id=$1
+ `, roomID, eventType.Type, defaultType, defaultValue).
+ Scan(&powerLevel)
+ if errors.Is(err, sql.ErrNoRows) {
+ err = nil
+ powerLevel = defaultValue
+ }
+ return powerLevel, err
+ } else {
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return 0, err
+ }
+ return levels.GetEventLevel(eventType), nil
}
- return levels.GetEventLevel(eventType), nil
}
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
- levels, err := store.GetPowerLevels(ctx, roomID)
- if err != nil {
- return false, err
+ if store.Dialect == dbutil.Postgres {
+ defaultType := "events_default"
+ defaultValue := 0
+ if eventType.IsState() {
+ defaultType = "state_default"
+ defaultValue = 50
+ }
+ var hasPower bool
+ err := store.
+ QueryRow(ctx, `SELECT
+ COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
+ >=
+ COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
+ FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
+ Scan(&hasPower)
+ if errors.Is(err, sql.ErrNoRows) {
+ err = nil
+ hasPower = defaultValue == 0
+ }
+ return hasPower, err
+ } else {
+ levels, err := store.GetPowerLevels(ctx, roomID)
+ if err != nil {
+ return false, err
+ }
+ return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
}
- return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
-}
-
-func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
- if evt.Type != event.StateCreate {
- return fmt.Errorf("invalid event type for create event: %s", evt.Type)
- } else if evt.RoomID == "" {
- return fmt.Errorf("room ID is empty")
- }
- _, err := store.Exec(ctx, `
- INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2)
- ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event
- `, evt.RoomID, dbutil.JSON{Data: evt})
- return err
-}
-
-func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) {
- err = store.
- QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID).
- Scan(&dbutil.JSON{Data: &evt})
- if errors.Is(err, sql.ErrNoRows) {
- return nil, nil
- } else if err != nil {
- return nil, err
- }
- if evt != nil {
- err = evt.Content.ParseRaw(event.StateCreate)
- }
- return
-}
-
-func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error {
- if roomID == "" {
- return fmt.Errorf("room ID is empty")
- }
- _, err := store.Exec(ctx, `
- INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2)
- ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules
- `, roomID, dbutil.JSON{Data: rules})
- return err
-}
-
-func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) {
- levels = &event.JoinRulesEventContent{}
- err = store.
- QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID).
- Scan(&dbutil.JSON{Data: &levels})
- if errors.Is(err, sql.ErrNoRows) {
- levels = nil
- err = nil
- }
- return
}
diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql
index 4679f1c6..a58cc56a 100644
--- a/sqlstatestore/v00-latest-revision.sql
+++ b/sqlstatestore/v00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v10 (compatible with v3+): Latest revision
+-- v0 -> v7 (compatible with v3+): Latest revision
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
@@ -26,7 +26,5 @@ CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
power_levels jsonb,
encryption jsonb,
- create_event jsonb,
- join_rules jsonb,
members_fetched BOOLEAN NOT NULL DEFAULT false
);
diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql
deleted file mode 100644
index 9f1b55c9..00000000
--- a/sqlstatestore/v08-create-event.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v8 (compatible with v3+): Add create event to room state table
-ALTER TABLE mx_room_state ADD COLUMN create_event jsonb;
diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql
deleted file mode 100644
index ca951068..00000000
--- a/sqlstatestore/v09-clear-empty-room-ids.sql
+++ /dev/null
@@ -1,3 +0,0 @@
--- v9 (compatible with v3+): Clear invalid rows
-DELETE FROM mx_room_state WHERE room_id='';
-DELETE FROM mx_user_profile WHERE room_id='' OR user_id='';
diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql
deleted file mode 100644
index 3074c46a..00000000
--- a/sqlstatestore/v10-join-rules.sql
+++ /dev/null
@@ -1,2 +0,0 @@
--- v10 (compatible with v3+): Add join rules to room state table
-ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb;
diff --git a/statestore.go b/statestore.go
index 2bd498dd..e728b885 100644
--- a/statestore.go
+++ b/statestore.go
@@ -34,12 +34,6 @@ type StateStore interface {
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
- SetCreate(ctx context.Context, evt *event.Event) error
- GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error)
-
- GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error)
- SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error
-
HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error)
MarkMembersFetched(ctx context.Context, roomID id.RoomID) error
GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error)
@@ -74,13 +68,9 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
err = store.SetPowerLevels(ctx, evt.RoomID, content)
case *event.EncryptionEventContent:
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
- case *event.CreateEventContent:
- err = store.SetCreate(ctx, evt)
- case *event.JoinRulesEventContent:
- err = store.SetJoinRules(ctx, evt.RoomID, content)
default:
switch evt.Type {
- case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate:
+ case event.StateMember, event.StatePowerLevels, event.StateEncryption:
zerolog.Ctx(ctx).Warn().
Stringer("event_id", evt.ID).
Str("event_type", evt.Type.Type).
@@ -111,14 +101,11 @@ type MemoryStateStore struct {
MembersFetched map[id.RoomID]bool `json:"members_fetched"`
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
- Create map[id.RoomID]*event.Event `json:"create"`
- JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"`
registrationsLock sync.RWMutex
membersLock sync.RWMutex
powerLevelsLock sync.RWMutex
encryptionLock sync.RWMutex
- joinRulesLock sync.RWMutex
}
func NewMemoryStateStore() StateStore {
@@ -128,8 +115,6 @@ func NewMemoryStateStore() StateStore {
MembersFetched: make(map[id.RoomID]bool),
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
- Create: make(map[id.RoomID]*event.Event),
- JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent),
}
}
@@ -313,9 +298,6 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
- if levels != nil && levels.CreateEvent == nil {
- levels.CreateEvent = store.Create[roomID]
- }
store.powerLevelsLock.RUnlock()
return
}
@@ -332,23 +314,6 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
}
-func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error {
- store.powerLevelsLock.Lock()
- store.Create[evt.RoomID] = evt
- if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil {
- pls.CreateEvent = evt
- }
- store.powerLevelsLock.Unlock()
- return nil
-}
-
-func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) {
- store.powerLevelsLock.RLock()
- evt := store.Create[roomID]
- store.powerLevelsLock.RUnlock()
- return evt, nil
-}
-
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
@@ -362,19 +327,6 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R
return store.Encryption[roomID], nil
}
-func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error {
- store.joinRulesLock.Lock()
- store.JoinRules[roomID] = content
- store.joinRulesLock.Unlock()
- return nil
-}
-
-func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) {
- store.joinRulesLock.RLock()
- defer store.joinRulesLock.RUnlock()
- return store.JoinRules[roomID], nil
-}
-
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go
index 0925b748..a09ba174 100644
--- a/synapseadmin/roomapi.go
+++ b/synapseadmin/roomapi.go
@@ -75,7 +75,8 @@ type RespListRooms struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
var resp RespListRooms
- reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
+ var reqURL string
+ reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
_, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
@@ -116,7 +117,6 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to
type ReqDeleteRoom struct {
Purge bool `json:"purge,omitempty"`
- ForcePurge bool `json:"force_purge,omitempty"`
Block bool `json:"block,omitempty"`
Message string `json:"message,omitempty"`
RoomName string `json:"room_name,omitempty"`
diff --git a/sync.go b/sync.go
index 598df8e0..9a2b9edf 100644
--- a/sync.go
+++ b/sync.go
@@ -90,7 +90,6 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
}
}()
- ctx = context.WithValue(ctx, SyncTokenContextKey, since)
for _, listener := range s.syncListeners {
if !listener(ctx, res, since) {
@@ -264,7 +263,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
for _, meta := range resp.Rooms.Invite {
- var inviteState []*event.Event
+ var inviteState []event.StrippedState
var inviteEvt *event.Event
for _, evt := range meta.State.Events {
if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() {
@@ -272,7 +271,12 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string
} else {
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
- inviteState = append(inviteState, evt)
+ inviteState = append(inviteState, event.StrippedState{
+ Content: evt.Content,
+ Type: evt.Type,
+ StateKey: evt.GetStateKey(),
+ Sender: evt.Sender,
+ })
}
}
if inviteEvt != nil {
diff --git a/url.go b/url.go
index 91b3d49d..d888956a 100644
--- a/url.go
+++ b/url.go
@@ -98,8 +98,10 @@ func (saup SynapseAdminURLPath) FullPath() []any {
// and appservice user ID set already.
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
- for k, v := range urlQuery {
- q.Set(k, v)
+ if urlQuery != nil {
+ for k, v := range urlQuery {
+ q.Set(k, v)
+ }
}
})
}
diff --git a/version.go b/version.go
index f00bbf39..6b8af5ef 100644
--- a/version.go
+++ b/version.go
@@ -4,11 +4,10 @@ import (
"fmt"
"regexp"
"runtime"
- "runtime/debug"
"strings"
)
-const Version = "v0.26.3"
+const Version = "v0.24.2"
var GoModVersion = ""
var Commit = ""
@@ -16,20 +15,11 @@ var VersionWithCommit = Version
var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go")
+var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`)
+
func init() {
- if GoModVersion == "" {
- info, _ := debug.ReadBuildInfo()
- if info != nil {
- for _, mod := range info.Deps {
- if mod.Path == "maunium.net/go/mautrix" {
- GoModVersion = mod.Version
- break
- }
- }
- }
- }
if GoModVersion != "" {
- match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion)
+ match := goModVersionRegex.FindStringSubmatch(GoModVersion)
if match != nil {
Commit = match[1]
}
diff --git a/versions.go b/versions.go
index 61b2e4ea..f87bddda 100644
--- a/versions.go
+++ b/versions.go
@@ -60,28 +60,20 @@ type UnstableFeature struct {
}
var (
- FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
- FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
- FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
- FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
- FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/}
- FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
- FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
- FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"}
- FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/}
- FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"}
- FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116}
- FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"}
+ FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
+ FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
+ FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
+ FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
+ FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
- BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
- BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
- BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
- BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
- BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
- BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
- BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
- BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"}
- BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"}
+ BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
+ BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
+ BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
+ BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
+ BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
+ BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
+ BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
)
func (versions *RespVersions) Supports(feature UnstableFeature) bool {
@@ -125,8 +117,6 @@ var (
SpecV113 = MustParseSpecVersion("v1.13")
SpecV114 = MustParseSpecVersion("v1.14")
SpecV115 = MustParseSpecVersion("v1.15")
- SpecV116 = MustParseSpecVersion("v1.16")
- SpecV117 = MustParseSpecVersion("v1.17")
)
func (svf SpecVersionFormat) String() string {