diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index dc4f17e2..c0add220 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -10,12 +10,12 @@ jobs:
runs-on: ubuntu-latest
name: Lint (latest)
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
- go-version: "1.25"
+ go-version: "1.26"
cache: true
- name: Install libolm
@@ -24,6 +24,7 @@ jobs:
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
+ go install honnef.co/go/tools/cmd/staticcheck@latest
export PATH="$HOME/go/bin:$PATH"
- name: Run pre-commit
@@ -34,14 +35,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- go-version: ["1.24", "1.25"]
- name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm)
+ go-version: ["1.25", "1.26"]
+ name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm)
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go ${{ matrix.go-version }}
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}
cache: true
@@ -61,7 +62,6 @@ jobs:
run: go test -json -v ./... 2>&1 | gotestfmt
- name: Test (jsonv2)
- if: matrix.go-version == '1.25'
env:
GOEXPERIMENT: jsonv2
run: go test -json -v ./... 2>&1 | gotestfmt
@@ -71,14 +71,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- go-version: ["1.24", "1.25"]
- name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm)
+ go-version: ["1.25", "1.26"]
+ name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm)
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go ${{ matrix.go-version }}
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}
cache: true
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 578349c9..9a9e7375 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -17,7 +17,7 @@ jobs:
lock-stale:
runs-on: ubuntu-latest
steps:
- - uses: dessant/lock-threads@v5
+ - uses: dessant/lock-threads@v6
id: lock
with:
issue-inactive-days: 90
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0b9785ae..616fccb2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v5.0.0
+ rev: v6.0.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
@@ -9,7 +9,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/tekwizely/pre-commit-golang
- rev: v1.0.0-rc.1
+ rev: v1.0.0-rc.4
hooks:
- id: go-imports-repo
args:
@@ -18,8 +18,7 @@ repos:
- "-w"
- id: go-vet-repo-mod
- id: go-mod-tidy
- # TODO enable this
- #- id: go-staticcheck-repo-mod
+ - id: go-staticcheck-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.4.2
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6c0ff70..f2829199 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,80 @@
+## v0.26.3 (2026-02-16)
+
+* Bumped minimum Go version to 1.25.
+* *(client)* Added fields for sending [MSC4354] sticky events.
+* *(bridgev2)* Added automatic message request accepting when sending message.
+* *(mediaproxy)* Added support for federation thumbnail endpoint.
+* *(crypto/ssss)* Improved support for recovery keys with slightly broken
+ metadata.
+* *(crypto)* Changed key import to call session received callback even for
+ sessions that already exist in the database.
+* *(appservice)* Fixed building websocket URL accidentally using file path
+ separators instead of always `/`.
+* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field.
+* *(client)* Fixed incorrect context usage in async uploads.
+* *(crypto)* Fixed panic when passing invalid input to megolm message index
+ parser used for debugging.
+* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned
+ up properly.
+
+[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354
+
+## v0.26.2 (2026-01-16)
+
+* *(bridgev2)* Added chunked portal deletion to avoid database locks when
+ deleting large portals.
+* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata
+ as per [MSC4392].
+* *(bridgev2/login)* Added `default_value` for user input fields.
+* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested
+ HTTP client settings and to reset active connections of the network connector.
+* *(bridgev2)* Added interface to let network connectors get the provisioning
+ API HTTP router and add new endpoints.
+* *(event)* Added blurhash field to Beeper link preview objects.
+* *(event)* Added [MSC4391] support for bot commands.
+* *(event)* Dropped [MSC4332] support for bot commands.
+* *(client)* Changed media download methods to return an error if the provided
+ MXC URI is empty.
+* *(client)* Stabilized support for [MSC4323].
+* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events.
+* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with
+ a portal re-ID call.
+
+[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391
+[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392
+
+## v0.26.1 (2025-12-16)
+
+* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return
+ the mime type from the callback rather than in the return get media return
+ value. The callback can now also redirect the caller to a different file.
+* *(federation)* Added join/knock/leave functions
+ (thanks to [@nexy7574] in [#422]).
+* *(federation/eventauth)* Fixed various incorrect checks.
+* *(client)* Added backoff for retrying media uploads to external URLs
+ (with MSC3870).
+* *(bridgev2/config)* Added support for overriding config fields using
+ environment variables.
+* *(bridgev2/commands)* Added command to mute chat on remote network.
+* *(bridgev2)* Added interface for network connectors to redirect to a different
+ user ID when handling an invite from Matrix.
+* *(bridgev2)* Added interface for signaling message request status of portals.
+* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag
+ is set in chat info.
+* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if
+ bridging the new one is successful.
+* *(bridgev2/mxmain)* Improved error message when trying to run bridge with
+ pre-megabridge database when no database migration exists.
+* *(bridgev2)* Improved reliability of database migration when enabling split
+ portals.
+* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats.
+* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public
+ portal rooms.
+* *(bridgev2)* Stopped hardcoding room versions in favor of checking
+ server capabilities to determine appropriate `/createRoom` parameters.
+
+[#422]: https://github.com/mautrix/go/pull/422
+
## v0.26.0 (2025-11-16)
* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent`
@@ -43,7 +120,7 @@
* *(federation)* Fixed validating auth for requests with query params.
* *(federation/eventauth)* Fixed typo causing restricted joins to not work.
-[MSC416]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169
+[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169
[#411]: github.com/mautrix/go/pull/411
[#420]: github.com/mautrix/go/pull/420
[#426]: github.com/mautrix/go/pull/426
@@ -360,6 +437,7 @@
[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156
[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190
[#288]: https://github.com/mautrix/go/pull/288
+[@onestacked]: https://github.com/onestacked
## v0.22.0 (2024-11-16)
diff --git a/README.md b/README.md
index ac41ca78..b1a2edf8 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,9 @@
# mautrix-go
[](https://pkg.go.dev/maunium.net/go/mautrix)
-A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks),
-[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp)
+A Golang Matrix framework. Used by [gomuks](https://gomuks.app),
+[go-neb](https://github.com/matrix-org/go-neb),
+[mautrix-whatsapp](https://github.com/mautrix/whatsapp)
and others.
Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net)
@@ -13,9 +14,10 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or
In addition to the basic client API features the original project has, this framework also has:
* Appservice support (Intent API like mautrix-python, room state storage, etc)
-* End-to-end encryption support (incl. interactive SAS verification)
+* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc)
* High-level module for building puppeting bridges
-* High-level module for building chat clients
+* Partial federation module (making requests, PDU processing and event authorization)
+* A media proxy server which can be used to expose anything as a Matrix media repo
* Wrapper functions for the Synapse admin API
* Structs for parsing event content
* Helpers for parsing and generating Matrix HTML
diff --git a/appservice/intent.go b/appservice/intent.go
index e4d8e100..5d43f190 100644
--- a/appservice/intent.go
+++ b/appservice/intent.go
@@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
}
func (intent *IntentAPI) Register(ctx context.Context) error {
- _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{
+ _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{
Username: intent.Localpart,
Type: mautrix.AuthTypeAppservice,
InhibitLogin: true,
@@ -222,6 +222,17 @@ func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID,
return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...)
}
+func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
+ if err := intent.EnsureJoined(ctx, roomID); err != nil {
+ return nil, err
+ }
+ if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
+ return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
+ }
+ contentJSON = intent.AddDoublePuppetValue(contentJSON)
+ return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...)
+}
+
// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead
func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
diff --git a/appservice/websocket.go b/appservice/websocket.go
index 1e401c53..ef65e65a 100644
--- a/appservice/websocket.go
+++ b/appservice/websocket.go
@@ -14,7 +14,7 @@ import (
"io"
"net/http"
"net/url"
- "path/filepath"
+ "path"
"strings"
"sync"
"sync/atomic"
@@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
var prefixMessage string
for unwrappedErr != nil {
errorData, jsonErr = json.Marshal(unwrappedErr)
- if errorData != nil && len(errorData) > 2 && jsonErr == nil {
+ if len(errorData) > 2 && jsonErr == nil {
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
prefixMessage = strings.TrimRight(prefixMessage, ": ")
break
@@ -374,7 +374,7 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
copiedURL := *as.hsURLForClient
parsed = &copiedURL
}
- parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
+ parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
if parsed.Scheme == "http" {
parsed.Scheme = "ws"
} else if parsed.Scheme == "https" {
diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go
index c84c2fd5..226adc90 100644
--- a/bridgev2/bridge.go
+++ b/bridgev2/bridge.go
@@ -16,6 +16,7 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exhttp"
"go.mau.fi/util/exsync"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
@@ -373,6 +374,42 @@ func (br *Bridge) StartLogins(ctx context.Context) error {
return nil
}
+func (br *Bridge) ResetNetworkConnections() {
+ nrn, ok := br.Network.(NetworkResettingNetwork)
+ if ok {
+ br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections")
+ nrn.ResetNetworkConnections()
+ return
+ }
+
+ br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually")
+ for _, login := range br.GetAllCachedUserLogins() {
+ login.Log.Debug().Msg("Disconnecting and recreating client for network reset")
+ ctx := login.Log.WithContext(br.BackgroundCtx)
+ login.Client.Disconnect()
+ err := login.recreateClient(ctx)
+ if err != nil {
+ login.Log.Err(err).Msg("Failed to recreate client during network reset")
+ login.BridgeState.Send(status.BridgeState{
+ StateEvent: status.StateUnknownError,
+ Error: "bridgev2-network-reset-fail",
+ Info: map[string]any{"go_error": err.Error()},
+ })
+ } else {
+ login.Client.Connect(ctx)
+ }
+ }
+ br.Log.Info().Msg("Finished resetting all user logins")
+}
+
+func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings {
+ mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings)
+ if ok {
+ return mchs.GetHTTPClientSettings()
+ }
+ return exhttp.SensibleClientSettings
+}
+
func (br *Bridge) IsStopping() bool {
return br.stopping.Load()
}
diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go
index 53282e41..eedae1e8 100644
--- a/bridgev2/bridgeconfig/backfill.go
+++ b/bridgev2/bridgeconfig/backfill.go
@@ -34,10 +34,12 @@ type BackfillQueueConfig struct {
MaxBatchesOverride map[string]int `yaml:"max_batches_override"`
}
-func (bqc *BackfillQueueConfig) GetOverride(name string) int {
- override, ok := bqc.MaxBatchesOverride[name]
- if !ok {
- return bqc.MaxBatches
+func (bqc *BackfillQueueConfig) GetOverride(names ...string) int {
+ for _, name := range names {
+ override, ok := bqc.MaxBatchesOverride[name]
+ if ok {
+ return override
+ }
}
- return override
+ return bqc.MaxBatches
}
diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go
index b1718f30..bd6b9c06 100644
--- a/bridgev2/bridgeconfig/config.go
+++ b/bridgev2/bridgeconfig/config.go
@@ -33,6 +33,8 @@ type Config struct {
Encryption EncryptionConfig `yaml:"encryption"`
Logging zeroconfig.Config `yaml:"logging"`
+ EnvConfigPrefix string `yaml:"env_config_prefix"`
+
ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"`
}
@@ -60,38 +62,40 @@ type CleanupOnLogouts struct {
}
type BridgeConfig struct {
- CommandPrefix string `yaml:"command_prefix"`
- PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
- PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
- AsyncEvents bool `yaml:"async_events"`
- SplitPortals bool `yaml:"split_portals"`
- ResendBridgeInfo bool `yaml:"resend_bridge_info"`
- NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
- BridgeStatusNotices string `yaml:"bridge_status_notices"`
- UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
- BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
- BridgeNotices bool `yaml:"bridge_notices"`
- TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
- OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
- MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
- DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
- CrossRoomReplies bool `yaml:"cross_room_replies"`
- OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
- RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"`
- KickMatrixUsers bool `yaml:"kick_matrix_users"`
- CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
- Relay RelayConfig `yaml:"relay"`
- Permissions PermissionConfig `yaml:"permissions"`
- Backfill BackfillConfig `yaml:"backfill"`
+ CommandPrefix string `yaml:"command_prefix"`
+ PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
+ PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
+ AsyncEvents bool `yaml:"async_events"`
+ SplitPortals bool `yaml:"split_portals"`
+ ResendBridgeInfo bool `yaml:"resend_bridge_info"`
+ NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
+ BridgeStatusNotices string `yaml:"bridge_status_notices"`
+ UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
+ UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"`
+ BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
+ BridgeNotices bool `yaml:"bridge_notices"`
+ TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
+ OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
+ MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
+ DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
+ CrossRoomReplies bool `yaml:"cross_room_replies"`
+ OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
+ RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"`
+ KickMatrixUsers bool `yaml:"kick_matrix_users"`
+ CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
+ Relay RelayConfig `yaml:"relay"`
+ Permissions PermissionConfig `yaml:"permissions"`
+ Backfill BackfillConfig `yaml:"backfill"`
}
type MatrixConfig struct {
- MessageStatusEvents bool `yaml:"message_status_events"`
- DeliveryReceipts bool `yaml:"delivery_receipts"`
- MessageErrorNotices bool `yaml:"message_error_notices"`
- SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
- FederateRooms bool `yaml:"federate_rooms"`
- UploadFileThreshold int64 `yaml:"upload_file_threshold"`
+ MessageStatusEvents bool `yaml:"message_status_events"`
+ DeliveryReceipts bool `yaml:"delivery_receipts"`
+ MessageErrorNotices bool `yaml:"message_error_notices"`
+ SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
+ FederateRooms bool `yaml:"federate_rooms"`
+ UploadFileThreshold int64 `yaml:"upload_file_threshold"`
+ GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"`
}
type AnalyticsConfig struct {
diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go
index 5a19b3ad..934613ca 100644
--- a/bridgev2/bridgeconfig/encryption.go
+++ b/bridgev2/bridgeconfig/encryption.go
@@ -16,6 +16,7 @@ type EncryptionConfig struct {
Require bool `yaml:"require"`
Appservice bool `yaml:"appservice"`
MSC4190 bool `yaml:"msc4190"`
+ MSC4392 bool `yaml:"msc4392"`
SelfSign bool `yaml:"self_sign"`
PlaintextMentions bool `yaml:"plaintext_mentions"`
diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go
index 898bf58a..9efe068e 100644
--- a/bridgev2/bridgeconfig/permissions.go
+++ b/bridgev2/bridgeconfig/permissions.go
@@ -41,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool {
_, hasExampleDomain := pc["example.com"]
_, hasExampleUser := pc["@admin:example.com"]
exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain)
- if len(pc) <= exampleLen {
- return false
- }
- return true
+ return len(pc) > exampleLen
}
func (pc PermissionConfig) Get(userID id.UserID) Permissions {
diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go
index 0dbff802..92515ea0 100644
--- a/bridgev2/bridgeconfig/upgrade.go
+++ b/bridgev2/bridgeconfig/upgrade.go
@@ -33,6 +33,7 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key")
helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices")
helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect")
+ helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects")
helper.Copy(up.Bool, "bridge", "bridge_matrix_leave")
helper.Copy(up.Bool, "bridge", "bridge_notices")
helper.Copy(up.Bool, "bridge", "tag_only_on_create")
@@ -100,6 +101,7 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "matrix", "sync_direct_chat_list")
helper.Copy(up.Bool, "matrix", "federate_rooms")
helper.Copy(up.Int, "matrix", "upload_file_threshold")
+ helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info")
helper.Copy(up.Str|up.Null, "analytics", "token")
helper.Copy(up.Str|up.Null, "analytics", "url")
@@ -161,6 +163,7 @@ func doUpgrade(helper up.Helper) {
} else {
helper.Copy(up.Bool, "encryption", "msc4190")
}
+ helper.Copy(up.Bool, "encryption", "msc4392")
helper.Copy(up.Bool, "encryption", "self_sign")
helper.Copy(up.Bool, "encryption", "allow_key_sharing")
if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" {
@@ -184,6 +187,8 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Int, "encryption", "rotation", "messages")
helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation")
+ helper.Copy(up.Str|up.Null, "env_config_prefix")
+
helper.Copy(up.Map, "logging")
}
@@ -211,6 +216,7 @@ var SpacedBlocks = [][]string{
{"backfill"},
{"double_puppet"},
{"encryption"},
+ {"env_config_prefix"},
{"logging"},
}
diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go
index 63d5876b..96d9fd5c 100644
--- a/bridgev2/bridgestate.go
+++ b/bridgev2/bridgestate.go
@@ -22,6 +22,8 @@ import (
"maunium.net/go/mautrix/format"
)
+var CatchBridgeStateQueuePanics = true
+
type BridgeStateQueue struct {
prevUnsent *status.BridgeState
prevSent *status.BridgeState
@@ -35,6 +37,8 @@ type BridgeStateQueue struct {
stopChan chan struct{}
stopReconnect atomic.Pointer[context.CancelFunc]
+
+ unknownErrorReconnects int
}
func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) {
@@ -84,23 +88,25 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() {
}
func (bsq *BridgeStateQueue) loop() {
- defer func() {
- err := recover()
- if err != nil {
- bsq.login.Log.Error().
- Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
- Any(zerolog.ErrorFieldName, err).
- Msg("Panic in bridge state loop")
- }
- }()
+ if CatchBridgeStateQueuePanics {
+ defer func() {
+ err := recover()
+ if err != nil {
+ bsq.login.Log.Error().
+ Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
+ Any(zerolog.ErrorFieldName, err).
+ Msg("Panic in bridge state loop")
+ }
+ }()
+ }
for state := range bsq.ch {
bsq.immediateSendBridgeState(state)
}
}
-func (bsq *BridgeStateQueue) scheduleNotice(ctx context.Context, triggeredBy status.BridgeState) {
+func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) {
log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger()
- ctx = log.WithContext(bsq.bridge.BackgroundCtx)
+ ctx := log.WithContext(bsq.bridge.BackgroundCtx)
if !bsq.waitForTransientDisconnectReconnect(ctx) {
return
}
@@ -131,7 +137,7 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge
if bsq.firstTransientDisconnect.IsZero() {
bsq.firstTransientDisconnect = time.Now()
}
- go bsq.scheduleNotice(ctx, state)
+ go bsq.scheduleNotice(state)
}
return
}
@@ -188,8 +194,14 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat
} else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError {
log.Debug().Msg("Not reconnecting as the previous state was not an unknown error")
return
+ } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects {
+ log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached")
+ return
}
- log.Info().Msg("Disconnecting and reconnecting login due to unknown error")
+ bsq.unknownErrorReconnects++
+ log.Info().
+ Int("reconnect_num", bsq.unknownErrorReconnects).
+ Msg("Disconnecting and reconnecting login due to unknown error")
bsq.login.Disconnect()
log.Debug().Msg("Disconnection finished, recreating client and reconnecting")
err := bsq.login.recreateClient(ctx)
diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go
index ad773ac8..1cae98fe 100644
--- a/bridgev2/commands/debug.go
+++ b/bridgev2/commands/debug.go
@@ -101,3 +101,25 @@ var CommandSendAccountData = &FullHandler{
RequiresPortal: true,
RequiresLogin: true,
}
+
+var CommandResetNetwork = &FullHandler{
+ Func: func(ce *Event) {
+ if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") {
+ nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork)
+ if ok {
+ nrn.ResetHTTPTransport()
+ } else {
+ ce.Reply("Network connector does not support resetting HTTP transport")
+ }
+ }
+ ce.Bridge.ResetNetworkConnections()
+ ce.React("✅️")
+ },
+ Name: "debug-reset-network",
+ Help: HelpMeta{
+ Section: HelpSectionAdmin,
+ Description: "Reset network connections to the remote network",
+ Args: "[--reset-transport]",
+ },
+ RequiresAdmin: true,
+}
diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go
index 80a7c733..96d62d3e 100644
--- a/bridgev2/commands/login.go
+++ b/bridgev2/commands/login.go
@@ -121,6 +121,7 @@ func fnLogin(ce *Event) {
ce.Reply("Failed to start login: %v", err)
return
}
+ ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process")
nextStep = checkLoginCommandDirectParams(ce, login, nextStep)
if nextStep != nil {
@@ -251,14 +252,19 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return fmt.Errorf("failed to upload image: %w", err)
}
content := &event.MessageEventContent{
- MsgType: event.MsgImage,
- FileName: "qr.png",
- URL: qrMXC,
- File: qrFile,
-
+ MsgType: event.MsgImage,
+ FileName: "qr.png",
+ URL: qrMXC,
+ File: qrFile,
Body: qr,
Format: event.FormatHTML,
FormattedBody: fmt.Sprintf("
%s
", html.EscapeString(qr)),
+ Info: &event.FileInfo{
+ MimeType: "image/png",
+ Width: qrSizePx,
+ Height: qrSizePx,
+ Size: len(qrData),
+ },
}
if *prevEventID != "" {
content.SetEdit(*prevEventID)
@@ -273,6 +279,36 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return nil
}
+func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
+ for _, att := range atts {
+ if att.FileName == "" {
+ return fmt.Errorf("missing attachment filename")
+ }
+ mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
+ if err != nil {
+ return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
+ }
+ content := &event.MessageEventContent{
+ MsgType: att.Type,
+ FileName: att.FileName,
+ URL: mxc,
+ File: file,
+ Info: &event.FileInfo{
+ MimeType: att.Info.MimeType,
+ Width: att.Info.Width,
+ Height: att.Info.Height,
+ Size: att.Info.Size,
+ },
+ Body: att.FileName,
+ }
+ _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
+ if err != nil {
+ return nil
+ }
+ }
+ return nil
+}
+
type contextKey int
const (
@@ -464,6 +500,7 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
}
func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
+ ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
if step.Instructions != "" {
ce.Reply(step.Instructions)
}
@@ -478,6 +515,10 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte
Override: override,
}).prompt(ce)
case bridgev2.LoginStepTypeUserInput:
+ err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
+ if err != nil {
+ ce.Reply("Failed to send attachments: %v", err)
+ }
(&userInputLoginCommandState{
Login: login.(bridgev2.LoginProcessUserInput),
RemainingFields: step.UserInputParams.Fields,
diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go
index 13a35687..391c3685 100644
--- a/bridgev2/commands/processor.go
+++ b/bridgev2/commands/processor.go
@@ -41,10 +41,11 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor {
}
proc.AddHandlers(
CommandHelp, CommandCancel,
- CommandRegisterPush, CommandSendAccountData, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
+ CommandRegisterPush, CommandSendAccountData, CommandResetNetwork,
+ CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
CommandSetRelay, CommandUnsetRelay,
- CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat,
+ CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute,
CommandSudo, CommandDoIn,
)
return proc
diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go
index af756c87..94c19739 100644
--- a/bridgev2/commands/relay.go
+++ b/bridgev2/commands/relay.go
@@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) {
}
onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly
var relay *bridgev2.UserLogin
- if len(ce.Args) == 0 {
+ if len(ce.Args) == 0 && ce.Portal.Receiver == "" {
relay = ce.User.GetDefaultLogin()
isLoggedIn := relay != nil
if onlySetDefaultRelays {
@@ -73,9 +73,19 @@ func fnSetRelay(ce *Event) {
}
}
} else {
- relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
+ var targetID networkid.UserLoginID
+ if ce.Portal.Receiver != "" {
+ targetID = ce.Portal.Receiver
+ if len(ce.Args) > 0 && ce.Args[0] != string(targetID) {
+ ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID)
+ return
+ }
+ } else {
+ targetID = networkid.UserLoginID(ce.Args[0])
+ }
+ relay = ce.Bridge.GetCachedUserLoginByID(targetID)
if relay == nil {
- ce.Reply("User login with ID `%s` not found", ce.Args[0])
+ ce.Reply("User login with ID `%s` not found", targetID)
return
} else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) {
// All good
diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go
index 24586387..c7b05a6e 100644
--- a/bridgev2/commands/startchat.go
+++ b/bridgev2/commands/startchat.go
@@ -80,7 +80,7 @@ var CommandStartChat = &FullHandler{
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
-func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
+func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
var remainingArgs []string
if len(ce.Args) > 1 {
remainingArgs = ce.Args[1:]
@@ -290,3 +290,44 @@ func fnSearch(ce *Event) {
}
ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n"))
}
+
+var CommandMute = &FullHandler{
+ Func: fnMute,
+ Name: "mute",
+ Aliases: []string{"unmute"},
+ Help: HelpMeta{
+ Section: HelpSectionChats,
+ Description: "Mute or unmute a chat on the remote network",
+ Args: "[duration]",
+ },
+ RequiresPortal: true,
+ RequiresLogin: true,
+ NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI],
+}
+
+func fnMute(ce *Event) {
+ _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats")
+ var mutedUntil int64
+ if ce.Command == "mute" {
+ mutedUntil = -1
+ if len(ce.Args) > 0 {
+ duration, err := time.ParseDuration(ce.Args[0])
+ if err != nil {
+ ce.Reply("Invalid duration: %v", err)
+ return
+ }
+ mutedUntil = time.Now().Add(duration).UnixMilli()
+ }
+ }
+ err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{
+ MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{
+ Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil},
+ Portal: ce.Portal,
+ },
+ })
+ if err != nil {
+ ce.Reply("Failed to %s chat: %v", ce.Command, err)
+ } else {
+ ce.React("✅️")
+ }
+}
diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go
index 0729cb83..05abddf0 100644
--- a/bridgev2/database/database.go
+++ b/bridgev2/database/database.go
@@ -7,13 +7,7 @@
package database
import (
- "encoding/json"
- "reflect"
- "strings"
-
"go.mau.fi/util/dbutil"
- "golang.org/x/exp/constraints"
- "golang.org/x/exp/maps"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -158,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID)
panic("bridge ID mismatch")
}
}
-
-func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) {
- if val, found := m[key]; found {
- floatVal, ok := val.(float64)
- if ok {
- return T(floatVal), true
- }
- tVal, ok := val.(T)
- if ok {
- return tVal, true
- }
- }
- return 0, false
-}
-
-func unmarshalMerge(input []byte, data any, extra *map[string]any) error {
- err := json.Unmarshal(input, data)
- if err != nil {
- return err
- }
- err = json.Unmarshal(input, extra)
- if err != nil {
- return err
- }
- if *extra == nil {
- *extra = make(map[string]any)
- }
- return nil
-}
-
-func marshalMerge(data any, extra map[string]any) ([]byte, error) {
- if extra == nil {
- return json.Marshal(data)
- }
- merged := make(map[string]any)
- maps.Copy(merged, extra)
- dataRef := reflect.ValueOf(data).Elem()
- dataType := dataRef.Type()
- for _, field := range reflect.VisibleFields(dataType) {
- parts := strings.Split(field.Tag.Get("json"), ",")
- if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" {
- continue
- }
- fieldVal := dataRef.FieldByIndex(field.Index)
- if fieldVal.IsZero() {
- delete(merged, parts[0])
- } else {
- merged[parts[0]] = fieldVal.Interface()
- }
- }
- return json.Marshal(merged)
-}
diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go
index c32929ad..16af35ca 100644
--- a/bridgev2/database/ghost.go
+++ b/bridgev2/database/ghost.go
@@ -7,12 +7,17 @@
package database
import (
+ "bytes"
"context"
"encoding/hex"
+ "encoding/json"
+ "fmt"
"go.mau.fi/util/dbutil"
+ "go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/id"
)
@@ -22,6 +27,55 @@ type GhostQuery struct {
*dbutil.QueryHelper[*Ghost]
}
+type ExtraProfile map[string]json.RawMessage
+
+func (ep *ExtraProfile) Set(key string, value any) error {
+ if key == "displayname" || key == "avatar_url" {
+ return fmt.Errorf("cannot set reserved profile key %q", key)
+ }
+ marshaled, err := json.Marshal(value)
+ if err != nil {
+ return err
+ }
+ if *ep == nil {
+ *ep = make(ExtraProfile)
+ }
+ (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled)
+ return nil
+}
+
+func (ep *ExtraProfile) With(key string, value any) *ExtraProfile {
+ exerrors.PanicIfNotNil(ep.Set(key, value))
+ return ep
+}
+
+func canonicalizeIfObject(data json.RawMessage) json.RawMessage {
+ if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
+ return canonicaljson.CanonicalJSONAssumeValid(data)
+ }
+ return data
+}
+
+func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) {
+ if len(*ep) == 0 {
+ return
+ }
+ if *dest == nil {
+ *dest = make(ExtraProfile)
+ }
+ for key, val := range *ep {
+ if key == "displayname" || key == "avatar_url" {
+ continue
+ }
+ existing, exists := (*dest)[key]
+ if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) {
+ (*dest)[key] = val
+ changed = true
+ }
+ }
+ return
+}
+
type Ghost struct {
BridgeID networkid.BridgeID
ID networkid.UserID
@@ -35,13 +89,14 @@ type Ghost struct {
ContactInfoSet bool
IsBot bool
Identifiers []string
+ ExtraProfile ExtraProfile
Metadata any
}
const (
getGhostBaseQuery = `
SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
+ name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
FROM ghost
`
getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2`
@@ -49,13 +104,14 @@ const (
insertGhostQuery = `
INSERT INTO ghost (
bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
+ name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
`
updateGhostQuery = `
UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6,
- name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12
+ name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10,
+ identifiers=$11, extra_profile=$12, metadata=$13
WHERE bridge_id=$1 AND id=$2
`
)
@@ -86,7 +142,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) {
&g.BridgeID, &g.ID,
&g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC,
&g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot,
- dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
+ dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
)
if err != nil {
return nil, err
@@ -116,6 +172,6 @@ func (g *Ghost) sqlVariables() []any {
g.BridgeID, g.ID,
g.Name, g.AvatarID, avatarHash, g.AvatarMXC,
g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot,
- dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
+ dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
}
}
diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go
index 9b3b1493..4fd599a8 100644
--- a/bridgev2/database/message.go
+++ b/bridgev2/database/message.go
@@ -11,9 +11,12 @@ import (
"crypto/sha256"
"database/sql"
"encoding/base64"
+ "fmt"
"strings"
+ "sync"
"time"
+ "github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -24,6 +27,7 @@ type MessageQuery struct {
BridgeID networkid.BridgeID
MetaType MetaTypeCreator
*dbutil.QueryHelper[*Message]
+ chunkDeleteLock sync.Mutex
}
type Message struct {
@@ -64,8 +68,8 @@ const (
getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`
getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5`
getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1`
- getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1`
- getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1`
+ getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1`
+ getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1`
getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4`
getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1`
@@ -96,6 +100,10 @@ const (
deleteMessagePartByRowIDQuery = `
DELETE FROM message WHERE bridge_id=$1 AND rowid=$2
`
+ deleteMessageChunkQuery = `
+ DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5
+ `
+ getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1`
)
func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) {
@@ -180,6 +188,85 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error {
return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID)
}
+func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) {
+ res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID)
+ if err != nil {
+ return 0, err
+ }
+ return res.RowsAffected()
+}
+
+func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) {
+ err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID)
+ return
+}
+
+const deleteChunkSize = 100_000
+
+func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error {
+ if mq.GetDB().Dialect != dbutil.SQLite {
+ return nil
+ }
+ log := zerolog.Ctx(ctx).With().
+ Str("action", "delete messages in chunks").
+ Stringer("portal_key", portal).
+ Logger()
+ if !mq.chunkDeleteLock.TryLock() {
+ log.Warn().Msg("Portal deletion lock is being held, waiting...")
+ mq.chunkDeleteLock.Lock()
+ log.Debug().Msg("Acquired portal deletion lock after waiting")
+ }
+ defer mq.chunkDeleteLock.Unlock()
+ total, err := mq.CountMessagesInPortal(ctx, portal)
+ if err != nil {
+ return fmt.Errorf("failed to count messages in portal: %w", err)
+ } else if total < deleteChunkSize/3 {
+ return nil
+ }
+ globalMaxRowID, err := mq.getMaxRowID(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get max row ID: %w", err)
+ }
+ log.Debug().
+ Int("total_count", total).
+ Int64("global_max_row_id", globalMaxRowID).
+ Msg("Portal has lots of messages, deleting in chunks to avoid database locks")
+ maxRowID := int64(deleteChunkSize)
+ globalMaxRowID += deleteChunkSize * 1.2
+ var dbTimeUsed time.Duration
+ globalStart := time.Now()
+ for total > 500 && maxRowID < globalMaxRowID {
+ start := time.Now()
+ count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID)
+ duration := time.Since(start)
+ dbTimeUsed += duration
+ if err != nil {
+ return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err)
+ }
+ total -= int(count)
+ maxRowID += deleteChunkSize
+ sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond))
+ log.Debug().
+ Int64("max_row_id", maxRowID).
+ Int64("deleted_count", count).
+ Int("remaining_count", total).
+ Dur("duration", duration).
+ Dur("sleep_time", sleepTime).
+ Msg("Deleted chunk of messages")
+ select {
+ case <-time.After(sleepTime):
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+ log.Debug().
+ Int("remaining_count", total).
+ Dur("db_time_used", dbTimeUsed).
+ Dur("total_duration", time.Since(globalStart)).
+ Msg("Finished chunked delete of messages in portal")
+ return nil
+}
+
func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) {
err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count)
return
diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go
index a230df19..0e6be286 100644
--- a/bridgev2/database/portal.go
+++ b/bridgev2/database/portal.go
@@ -56,30 +56,31 @@ type Portal struct {
networkid.PortalKey
MXID id.RoomID
- ParentKey networkid.PortalKey
- RelayLoginID networkid.UserLoginID
- OtherUserID networkid.UserID
- Name string
- Topic string
- AvatarID networkid.AvatarID
- AvatarHash [32]byte
- AvatarMXC id.ContentURIString
- NameSet bool
- TopicSet bool
- AvatarSet bool
- NameIsCustom bool
- InSpace bool
- RoomType RoomType
- Disappear DisappearingSetting
- CapState CapabilityState
- Metadata any
+ ParentKey networkid.PortalKey
+ RelayLoginID networkid.UserLoginID
+ OtherUserID networkid.UserID
+ Name string
+ Topic string
+ AvatarID networkid.AvatarID
+ AvatarHash [32]byte
+ AvatarMXC id.ContentURIString
+ NameSet bool
+ TopicSet bool
+ AvatarSet bool
+ NameIsCustom bool
+ InSpace bool
+ MessageRequest bool
+ RoomType RoomType
+ Disappear DisappearingSetting
+ CapState CapabilityState
+ Metadata any
}
const (
getPortalBaseQuery = `
SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
- name_set, topic_set, avatar_set, name_is_custom, in_space,
+ name_set, topic_set, avatar_set, name_is_custom, in_space, message_request,
room_type, disappear_type, disappear_timer, cap_state,
metadata
FROM portal
@@ -88,7 +89,7 @@ const (
getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')`
getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL`
- getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND receiver=''`
+ getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC`
getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2`
getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3`
getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1`
@@ -101,11 +102,11 @@ const (
bridge_id, id, receiver, mxid,
parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
- name_set, avatar_set, topic_set, name_is_custom, in_space,
+ name_set, avatar_set, topic_set, name_is_custom, in_space, message_request,
room_type, disappear_type, disappear_timer, cap_state,
metadata, relay_bridge_id
) VALUES (
- $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23,
+ $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24,
CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END
)
`
@@ -114,8 +115,8 @@ const (
SET mxid=$4, parent_id=$5, parent_receiver=$6,
relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END,
other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13,
- name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18,
- room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23
+ name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19,
+ room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24
WHERE bridge_id=$1 AND id=$2 AND receiver=$3
`
deletePortalQuery = `
@@ -148,7 +149,10 @@ const (
)
`
fixParentsAfterSplitPortalMigrationQuery = `
- UPDATE portal SET parent_receiver=receiver WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>'';
+ UPDATE portal
+ SET parent_receiver=receiver
+ WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''
+ AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver);
`
)
@@ -238,7 +242,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
&p.BridgeID, &p.ID, &p.Receiver, &mxid,
&parentID, &parentReceiver, &relayLoginID, &otherUserID,
&p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC,
- &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace,
+ &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest,
&p.RoomType, &disappearType, &disappearTimer,
dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata},
)
@@ -285,7 +289,7 @@ func (p *Portal) sqlVariables() []any {
p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID),
dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID),
p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC,
- p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace,
+ p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest,
p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer),
dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata},
}
diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql
index efde8816..6092dc24 100644
--- a/bridgev2/database/upgrades/00-latest.sql
+++ b/bridgev2/database/upgrades/00-latest.sql
@@ -1,4 +1,4 @@
--- v0 -> v24 (compatible with v9+): Latest revision
+-- v0 -> v27 (compatible with v9+): Latest revision
CREATE TABLE "user" (
bridge_id TEXT NOT NULL,
mxid TEXT NOT NULL,
@@ -48,6 +48,7 @@ CREATE TABLE portal (
topic_set BOOLEAN NOT NULL,
name_is_custom BOOLEAN NOT NULL DEFAULT false,
in_space BOOLEAN NOT NULL,
+ message_request BOOLEAN NOT NULL DEFAULT false,
room_type TEXT NOT NULL,
disappear_type TEXT,
disappear_timer BIGINT,
@@ -64,6 +65,7 @@ CREATE TABLE portal (
ON DELETE SET NULL ON UPDATE CASCADE
);
CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid);
+CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
CREATE TABLE ghost (
bridge_id TEXT NOT NULL,
@@ -78,6 +80,7 @@ CREATE TABLE ghost (
contact_info_set BOOLEAN NOT NULL,
is_bot BOOLEAN NOT NULL,
identifiers jsonb NOT NULL,
+ extra_profile jsonb,
metadata jsonb NOT NULL,
PRIMARY KEY (bridge_id, id)
@@ -138,6 +141,7 @@ CREATE TABLE disappearing_message (
REFERENCES portal (bridge_id, mxid)
ON DELETE CASCADE
);
+CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
CREATE TABLE reaction (
bridge_id TEXT NOT NULL,
diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql
new file mode 100644
index 00000000..b9d82a7a
--- /dev/null
+++ b/bridgev2/database/upgrades/25-message-requests.sql
@@ -0,0 +1,2 @@
+-- v25 (compatible with v9+): Flag for message request portals
+ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false;
diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql
new file mode 100644
index 00000000..ae5d8cad
--- /dev/null
+++ b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql
@@ -0,0 +1,3 @@
+-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents
+CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
+CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql
new file mode 100644
index 00000000..e8e0549a
--- /dev/null
+++ b/bridgev2/database/upgrades/27-ghost-extra-profile.sql
@@ -0,0 +1,2 @@
+-- v27 (compatible with v9+): Add column for extra ghost profile metadata
+ALTER TABLE ghost ADD COLUMN extra_profile jsonb;
diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go
index 9fa6569a..00ff01c9 100644
--- a/bridgev2/database/userlogin.go
+++ b/bridgev2/database/userlogin.go
@@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin {
func (u *UserLogin) sqlVariables() []any {
var remoteProfile dbutil.JSON
- if !u.RemoteProfile.IsEmpty() {
+ if !u.RemoteProfile.IsZero() {
remoteProfile.Data = &u.RemoteProfile
}
return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}}
diff --git a/bridgev2/errors.go b/bridgev2/errors.go
index e81b8953..f6677d2e 100644
--- a/bridgev2/errors.go
+++ b/bridgev2/errors.go
@@ -38,42 +38,47 @@ var ErrNotLoggedIn = errors.New("not logged in")
// but direct media is not enabled.
var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled")
+var ErrPortalIsDeleted = errors.New("portal is deleted")
+var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event")
+
// Common message status errors
var (
- ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
- ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false)
- ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
- ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
- ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
- ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
- ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
- ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
- ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
- ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
- ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
- ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
- ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
- ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
- ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
+ ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
+ ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false)
+ ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
+ ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
+ ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
+ ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
+ ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
+ ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
+ ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
+ ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
+ ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
+ ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
+ ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
+ ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
+ ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go
index 6cef6f06..590dd1dc 100644
--- a/bridgev2/ghost.go
+++ b/bridgev2/ghost.go
@@ -9,12 +9,15 @@ package bridgev2
import (
"context"
"crypto/sha256"
+ "encoding/json"
"fmt"
+ "maps"
"net/http"
+ "slices"
"github.com/rs/zerolog"
+ "go.mau.fi/util/exerrors"
"go.mau.fi/util/exmime"
- "golang.org/x/exp/slices"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -134,10 +137,11 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32
}
type UserInfo struct {
- Identifiers []string
- Name *string
- Avatar *Avatar
- IsBot *bool
+ Identifiers []string
+ Name *string
+ Avatar *Avatar
+ IsBot *bool
+ ExtraProfile database.ExtraProfile
ExtraUpdates ExtraUpdater[*Ghost]
}
@@ -185,9 +189,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
return true
}
-func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra {
+func (ghost *Ghost) getExtraProfileMeta() any {
bridgeName := ghost.Bridge.Network.GetName()
- return &event.BeeperProfileExtra{
+ baseExtra := &event.BeeperProfileExtra{
RemoteID: string(ghost.ID),
Identifiers: ghost.Identifiers,
Service: bridgeName.BeeperBridgeType,
@@ -195,23 +199,35 @@ func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra {
IsBridgeBot: false,
IsNetworkBot: ghost.IsBot,
}
+ if len(ghost.ExtraProfile) == 0 {
+ return baseExtra
+ }
+ mergedExtra := maps.Clone(ghost.ExtraProfile)
+ baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra))
+ exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra))
+ return mergedExtra
}
-func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool {
- if identifiers != nil {
- slices.Sort(identifiers)
- }
- if ghost.ContactInfoSet &&
- (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) &&
- (isBot == nil || *isBot == ghost.IsBot) {
+func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool {
+ if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta {
+ ghost.ContactInfoSet = false
return false
}
if identifiers != nil {
+ slices.Sort(identifiers)
+ }
+ changed := extraProfile.CopyTo(&ghost.ExtraProfile)
+ if identifiers != nil {
+ changed = changed || !slices.Equal(identifiers, ghost.Identifiers)
ghost.Identifiers = identifiers
}
if isBot != nil {
+ changed = changed || *isBot != ghost.IsBot
ghost.IsBot = *isBot
}
+ if ghost.ContactInfoSet && !changed {
+ return false
+ }
err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta())
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata")
@@ -234,7 +250,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool {
}
func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) {
- if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
+ if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
return
}
info, err := source.Client.GetUserInfo(ctx, ghost)
@@ -244,12 +260,16 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin
zerolog.Ctx(ctx).Debug().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
+ Bool("has_avatar", ghost.AvatarMXC != "").
+ Bool("avatar_set", ghost.AvatarSet).
Msg("Updating ghost info in IfNecessary call")
ghost.UpdateInfo(ctx, info)
} else {
zerolog.Ctx(ctx).Trace().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
+ Bool("has_avatar", ghost.AvatarMXC != "").
+ Bool("avatar_set", ghost.AvatarSet).
Msg("No ghost info received in IfNecessary call")
}
}
@@ -277,9 +297,14 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) {
}
if info.Avatar != nil {
update = ghost.UpdateAvatar(ctx, info.Avatar) || update
+ } else if oldAvatar == "" && !ghost.AvatarSet {
+ // Special case: nil avatar means we're not expecting one ever, if we don't currently have
+ // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary.
+ ghost.AvatarSet = true
+ update = true
}
- if info.Identifiers != nil || info.IsBot != nil {
- update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update
+ if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil {
+ update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update
}
if info.ExtraUpdates != nil {
update = info.ExtraUpdates(ctx, ghost) || update
diff --git a/bridgev2/login.go b/bridgev2/login.go
index 46dcf7da..b8321719 100644
--- a/bridgev2/login.go
+++ b/bridgev2/login.go
@@ -13,6 +13,7 @@ import (
"strings"
"maunium.net/go/mautrix/bridgev2/networkid"
+ "maunium.net/go/mautrix/event"
)
// LoginProcess represents a single occurrence of a user logging into the remote network.
@@ -179,6 +180,7 @@ const (
LoginInputFieldTypeURL LoginInputFieldType = "url"
LoginInputFieldTypeDomain LoginInputFieldType = "domain"
LoginInputFieldTypeSelect LoginInputFieldType = "select"
+ LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code"
)
type LoginInputDataField struct {
@@ -190,6 +192,8 @@ type LoginInputDataField struct {
Name string `json:"name"`
// The description of the field shown to the user.
Description string `json:"description"`
+ // A default value that the client can pre-fill the field with.
+ DefaultValue string `json:"default_value,omitempty"`
// A regex pattern that the client can use to validate input client-side.
Pattern string `json:"pattern,omitempty"`
// For fields of type select, the valid options.
@@ -269,6 +273,23 @@ func (f *LoginInputDataField) FillDefaultValidate() {
type LoginUserInputParams struct {
// The fields that the user needs to fill in.
Fields []LoginInputDataField `json:"fields"`
+
+ // Attachments to display alongside the input fields.
+ Attachments []*LoginUserInputAttachment `json:"attachments"`
+}
+
+type LoginUserInputAttachment struct {
+ Type event.MessageType `json:"type,omitempty"`
+ FileName string `json:"filename,omitempty"`
+ Content []byte `json:"content,omitempty"`
+ Info LoginUserInputAttachmentInfo `json:"info,omitempty"`
+}
+
+type LoginUserInputAttachmentInfo struct {
+ MimeType string `json:"mimetype,omitempty"`
+ Width int `json:"w,omitempty"`
+ Height int `json:"h,omitempty"`
+ Size int `json:"size,omitempty"`
}
type LoginCompleteParams struct {
diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go
index 3e05837f..5a2df953 100644
--- a/bridgev2/matrix/connector.go
+++ b/bridgev2/matrix/connector.go
@@ -81,6 +81,8 @@ type Connector struct {
MediaConfig mautrix.RespMediaConfig
SpecVersions *mautrix.RespVersions
+ SpecCaps *mautrix.RespCapabilities
+ specCapsLock sync.Mutex
Capabilities *bridgev2.MatrixCapabilities
IgnoreUnsupportedServer bool
@@ -142,16 +144,20 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) {
br.EventProcessor.On(event.EventReaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent)
+ br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent)
br.EventProcessor.On(event.StateMember, br.handleRoomEvent)
br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent)
+ br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent)
br.EventProcessor.On(event.StateTopic, br.handleRoomEvent)
br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent)
br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent)
br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent)
+ br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent)
br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent)
br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent)
+ br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent)
br.Bot = br.AS.BotIntent()
br.Crypto = NewCryptoHelper(br)
br.Bridge.Commands.(*commands.Processor).AddHandlers(
@@ -363,6 +369,8 @@ func (br *Connector) ensureConnection(ctx context.Context) {
br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites)
br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending)
br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange)
+ br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) ||
+ (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo)
break
}
}
@@ -407,6 +415,21 @@ func (br *Connector) ensureConnection(ctx context.Context) {
br.Bot.EnsureAppserviceConnection(ctx)
}
+func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities {
+ br.specCapsLock.Lock()
+ defer br.specCapsLock.Unlock()
+ if br.SpecCaps != nil {
+ return br.SpecCaps
+ }
+ caps, err := br.Bot.Capabilities(ctx)
+ if err != nil {
+ br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver")
+ return nil
+ }
+ br.SpecCaps = caps
+ return caps
+}
+
func (br *Connector) fetchMediaConfig(ctx context.Context) {
cfg, err := br.Bot.GetMediaConfig(ctx)
if err != nil {
@@ -621,7 +644,7 @@ func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventT
}
}
}
- return br.Bot.FullStateEvent(ctx, roomID, eventType, "")
+ return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey)
}
func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go
index f4a2e9a0..7f18f1f5 100644
--- a/bridgev2/matrix/crypto.go
+++ b/bridgev2/matrix/crypto.go
@@ -38,9 +38,9 @@ func init() {
var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil)
-var NoSessionFound = crypto.NoSessionFound
-var DuplicateMessageIndex = crypto.DuplicateMessageIndex
-var UnknownMessageIndex = olm.UnknownMessageIndex
+var NoSessionFound = crypto.ErrNoSessionFound
+var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex
+var UnknownMessageIndex = olm.ErrUnknownMessageIndex
type CryptoHelper struct {
bridge *Connector
@@ -439,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy
var encrypted *event.EncryptedEventContent
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
- if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
+ if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) {
return
}
helper.log.Debug().Err(err).
diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go
index 1f82f77f..f7254bd4 100644
--- a/bridgev2/matrix/intent.go
+++ b/bridgev2/matrix/intent.go
@@ -9,6 +9,7 @@ package matrix
import (
"bytes"
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -27,6 +28,7 @@ import (
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
"maunium.net/go/mautrix/crypto/attachment"
+ "maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
@@ -43,6 +45,7 @@ type ASIntent struct {
var _ bridgev2.MatrixAPI = (*ASIntent)(nil)
var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil)
+var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil)
func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) {
if extra == nil {
@@ -56,7 +59,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
Extra: content.Raw,
})
}
- if eventType != event.EventReaction && eventType != event.EventRedaction {
+ if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction {
msgContent, ok := content.Parsed.(*event.MessageEventContent)
if ok {
msgContent.AddPerMessageProfileFallback()
@@ -84,6 +87,21 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()})
}
+func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) {
+ if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
+ return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
+ }
+ if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil {
+ return nil, fmt.Errorf("failed to check if room is encrypted: %w", err)
+ } else if encrypted && as.Connector.Crypto != nil {
+ if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil {
+ return nil, err
+ }
+ eventType = event.EventEncrypted
+ }
+ return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID})
+}
+
func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) {
targetContent, ok := content.Parsed.(*event.MemberEventContent)
if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" {
@@ -403,6 +421,7 @@ func (as *ASIntent) UploadMediaStream(
removeAndClose(replFile)
removeAndClose(tempFile)
}
+ req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
startedAsyncUpload = true
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
@@ -435,6 +454,7 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn
as.Connector.uploadSema.Release(int64(len(req.ContentBytes)))
}
}
+ req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
if resp != nil {
@@ -466,11 +486,62 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr
return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL)
}
-func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error {
- if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
- return nil
+func dataToFields(data any) (map[string]json.RawMessage, error) {
+ fields, ok := data.(map[string]json.RawMessage)
+ if ok {
+ return fields, nil
}
- return as.Matrix.BeeperUpdateProfile(ctx, data)
+ d, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+ d = canonicaljson.CanonicalJSONAssumeValid(d)
+ err = json.Unmarshal(d, &fields)
+ return fields, err
+}
+
+func marshalField(val any) json.RawMessage {
+ data, _ := json.Marshal(val)
+ if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
+ return canonicaljson.CanonicalJSONAssumeValid(data)
+ }
+ return data
+}
+
+var nullJSON = json.RawMessage("null")
+
+func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error {
+ if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
+ return as.Matrix.BeeperUpdateProfile(ctx, data)
+ } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo {
+ fields, err := dataToFields(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal fields: %w", err)
+ }
+ currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID)
+ if err != nil {
+ return fmt.Errorf("failed to get current profile: %w", err)
+ }
+ for key, val := range fields {
+ existing, ok := currentProfile.Extra[key]
+ if !ok {
+ if bytes.Equal(val, nullJSON) {
+ continue
+ }
+ err = as.Matrix.SetProfileField(ctx, key, val)
+ } else if !bytes.Equal(marshalField(existing), val) {
+ if bytes.Equal(val, nullJSON) {
+ err = as.Matrix.DeleteProfileField(ctx, key)
+ } else {
+ err = as.Matrix.SetProfileField(ctx, key, val)
+ }
+ }
+ if err != nil {
+ return fmt.Errorf("failed to set profile field %q: %w", key, err)
+ }
+ }
+ }
+ return nil
}
func (as *ASIntent) GetMXID() id.UserID {
@@ -512,6 +583,39 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent {
return content
}
+func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) {
+ if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
+ // Hungryserv doesn't override the capabilities endpoint nor do room versions
+ return
+ }
+ caps := as.Connector.fetchCapabilities(ctx)
+ roomVer := req.RoomVersion
+ if roomVer == "" && caps != nil && caps.RoomVersions != nil {
+ roomVer = id.RoomVersion(caps.RoomVersions.Default)
+ }
+ if roomVer != "" && !roomVer.PrivilegedRoomCreators() {
+ return
+ }
+ creators, _ := req.CreationContent["additional_creators"].([]id.UserID)
+ creators = append(slices.Clone(creators), as.GetMXID())
+ if req.PowerLevelOverride != nil {
+ for _, creator := range creators {
+ delete(req.PowerLevelOverride.Users, creator)
+ }
+ }
+ for _, evt := range req.InitialState {
+ if evt.Type != event.StatePowerLevels {
+ continue
+ }
+ content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent)
+ if ok {
+ for _, creator := range creators {
+ delete(content.Users, creator)
+ }
+ }
+ }
+}
+
func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) {
if as.Connector.Config.Encryption.Default {
req.InitialState = append(req.InitialState, &event.Event{
@@ -527,6 +631,7 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom)
}
req.CreationContent["m.federate"] = false
}
+ as.filterCreateRequestForV12(ctx, req)
resp, err := as.Matrix.CreateRoom(ctx, req)
if err != nil {
return "", err
@@ -680,10 +785,10 @@ func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.E
}
if evt.Type == event.EventEncrypted {
- if as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt {
+ if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt {
return nil, errors.New("can't decrypt the event")
}
- return as.Matrix.Crypto.Decrypt(ctx, evt)
+ return as.Connector.Crypto.Decrypt(ctx, evt)
}
return evt, nil
diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go
index 6c94bccc..954d0ad9 100644
--- a/bridgev2/matrix/matrix.go
+++ b/bridgev2/matrix/matrix.go
@@ -68,6 +68,10 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event)
case event.EphemeralEventTyping:
typingContent := evt.Content.AsTyping()
typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser)
+ case event.BeeperEphemeralEventAIStream:
+ if br.shouldIgnoreEvent(evt) {
+ return
+ }
}
br.Bridge.QueueMatrixEvent(ctx, evt)
}
@@ -127,6 +131,7 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event,
Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, requesting keys and waiting longer...")
+ //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false)
@@ -230,7 +235,6 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event
go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount)
decrypted.Mautrix.CheckpointSent = true
decrypted.Mautrix.DecryptionDuration = duration
- decrypted.Mautrix.EventSource |= event.SourceDecrypted
br.EventProcessor.Dispatch(ctx, decrypted)
if errorEventID != nil && *errorEventID != "" {
_, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID)
diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go
index 0f6aa68c..f5e438de 100644
--- a/bridgev2/matrix/mxmain/dberror.go
+++ b/bridgev2/matrix/mxmain/dberror.go
@@ -66,7 +66,12 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s
} else if errors.Is(err, dbutil.ErrForeignTables) {
br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info")
} else if errors.Is(err, dbutil.ErrNotOwned) {
- br.Log.Info().Msg("Sharing the same database with different programs is not supported")
+ var noe dbutil.NotOwnedError
+ if errors.As(err, &noe) && noe.Owner == br.Name {
+ br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?")
+ } else {
+ br.Log.Info().Msg("Sharing the same database with different programs is not supported")
+ }
} else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) {
br.Log.Info().Msg("Downgrading the bridge is not supported")
}
diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go
new file mode 100644
index 00000000..1b4f1467
--- /dev/null
+++ b/bridgev2/matrix/mxmain/envconfig.go
@@ -0,0 +1,161 @@
+// Copyright (c) 2025 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mxmain
+
+import (
+ "fmt"
+ "iter"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+
+ "go.mau.fi/util/random"
+)
+
+var randomParseFilePrefix = random.String(16) + "READFILE:"
+
+func parseEnv(prefix string) iter.Seq2[[]string, string] {
+ return func(yield func([]string, string) bool) {
+ for _, s := range os.Environ() {
+ if !strings.HasPrefix(s, prefix) {
+ continue
+ }
+ kv := strings.SplitN(s, "=", 2)
+ key := strings.TrimPrefix(kv[0], prefix)
+ value := kv[1]
+ if strings.HasSuffix(key, "_FILE") {
+ key = strings.TrimSuffix(key, "_FILE")
+ value = randomParseFilePrefix + value
+ }
+ key = strings.ToLower(key)
+ if !strings.ContainsRune(key, '.') {
+ key = strings.ReplaceAll(key, "__", ".")
+ }
+ if !yield(strings.Split(key, "."), value) {
+ return
+ }
+ }
+ }
+}
+
+func reflectYAMLFieldName(f *reflect.StructField) string {
+ parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2)
+ fieldName := parts[0]
+ if fieldName == "-" && len(parts) == 1 {
+ return ""
+ }
+ if fieldName == "" {
+ return strings.ToLower(f.Name)
+ }
+ return fieldName
+}
+
+type reflectGetResult struct {
+ val reflect.Value
+ valKind reflect.Kind
+ remainingPath []string
+}
+
+func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) {
+ if len(path) == 0 {
+ return &reflectGetResult{val: rv, valKind: rv.Kind()}, true
+ }
+ if rv.Kind() == reflect.Ptr {
+ rv = rv.Elem()
+ }
+ switch rv.Kind() {
+ case reflect.Map:
+ return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true
+ case reflect.Struct:
+ fields := reflect.VisibleFields(rv.Type())
+ for _, field := range fields {
+ fieldName := reflectYAMLFieldName(&field)
+ if fieldName != "" && fieldName == path[0] {
+ return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:])
+ }
+ }
+ default:
+ }
+ return nil, false
+}
+
+func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) {
+ if len(path) > 0 && path[0] == "network" {
+ return reflectGetYAML(network, path[1:])
+ }
+ return reflectGetYAML(main, path)
+}
+
+func formatKeyString(key []string) string {
+ return strings.Join(key, "->")
+}
+
+func UpdateConfigFromEnv(cfg, networkData any, prefix string) error {
+ cfgVal := reflect.ValueOf(cfg)
+ networkVal := reflect.ValueOf(networkData)
+ for key, value := range parseEnv(prefix) {
+ field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key)
+ if !ok {
+ return fmt.Errorf("%s not found", formatKeyString(key))
+ }
+ if strings.HasPrefix(value, randomParseFilePrefix) {
+ filepath := strings.TrimPrefix(value, randomParseFilePrefix)
+ fileData, err := os.ReadFile(filepath)
+ if err != nil {
+ return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err)
+ }
+ value = strings.TrimSpace(string(fileData))
+ }
+ var parsedVal any
+ var err error
+ switch field.valKind {
+ case reflect.String:
+ parsedVal = value
+ case reflect.Bool:
+ parsedVal, err = strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ parsedVal, err = strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ parsedVal, err = strconv.ParseUint(value, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ case reflect.Float32, reflect.Float64:
+ parsedVal, err = strconv.ParseFloat(value, 64)
+ if err != nil {
+ return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
+ }
+ default:
+ return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key))
+ }
+ if field.val.Kind() == reflect.Ptr {
+ if field.val.IsNil() {
+ field.val.Set(reflect.New(field.val.Type().Elem()))
+ }
+ field.val = field.val.Elem()
+ }
+ if field.val.Kind() == reflect.Map {
+ key = key[:len(key)-len(field.remainingPath)]
+ mapKeyStr := strings.Join(field.remainingPath, ".")
+ key = append(key, mapKeyStr)
+ if field.val.Type().Key().Kind() != reflect.String {
+ return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key))
+ }
+ field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal))
+ } else {
+ field.val.Set(reflect.ValueOf(parsedVal))
+ }
+ }
+ return nil
+}
diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml
index 27c3aa67..ccc81c4b 100644
--- a/bridgev2/matrix/mxmain/example-config.yaml
+++ b/bridgev2/matrix/mxmain/example-config.yaml
@@ -29,6 +29,9 @@ bridge:
# How long after an unknown error should the bridge attempt a full reconnect?
# Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value.
unknown_error_auto_reconnect: null
+ # Maximum number of times to do the auto-reconnect above.
+ # The counter is per login, but is never reset except on logout and restart.
+ unknown_error_max_auto_reconnects: 10
# Should leaving Matrix rooms be bridged as leaving groups on the remote network?
bridge_matrix_leave: false
@@ -241,6 +244,9 @@ matrix:
# The threshold as bytes after which the bridge should roundtrip uploads via the disk
# rather than keeping the whole file in memory.
upload_file_threshold: 5242880
+ # Should the bridge set additional custom profile info for ghosts?
+ # This can make a lot of requests, as there's no batch profile update endpoint.
+ ghost_extra_profile_info: false
# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors.
analytics:
@@ -378,6 +384,8 @@ encryption:
# Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861).
# Changing this option requires updating the appservice registration file.
msc4190: false
+ # Whether to encrypt reactions and reply metadata as per MSC4392.
+ msc4392: false
# Should the bridge bot generate a recovery key and cross-signing keys and verify itself?
# Note that without the latest version of MSC4190, this will fail if you reset the bridge database.
# The generated recovery key will be saved in the kv_store table under `recovery_key`.
@@ -444,6 +452,16 @@ encryption:
# You should not enable this option unless you understand all the implications.
disable_device_change_key_rotation: false
+# Prefix for environment variables. All variables with this prefix must map to valid config fields.
+# Nesting in variable names is represented with a dot (.).
+# If there are no dots in the name, two underscores (__) are replaced with a dot.
+#
+# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token.
+# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily.
+#
+# If this is null, reading config fields from environment will be disabled.
+env_config_prefix: null
+
# Logging config. See https://github.com/tulir/zeroconfig for details.
logging:
min_level: debug
diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go
index c8eb820b..97cdeddf 100644
--- a/bridgev2/matrix/mxmain/legacymigrate.go
+++ b/bridgev2/matrix/mxmain/legacymigrate.go
@@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB(
}
var dbVersion int
err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion)
- if dbVersion < expectedVersion {
+ if err != nil {
+ log.Fatal().Err(err).Msg("Failed to get database version")
+ return
+ } else if dbVersion < expectedVersion {
log.Fatal().
Int("expected_version", expectedVersion).
Int("version", dbVersion).
diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go
index ca0ca5f7..1e8b51d1 100644
--- a/bridgev2/matrix/mxmain/main.go
+++ b/bridgev2/matrix/mxmain/main.go
@@ -354,6 +354,13 @@ func (br *BridgeMain) LoadConfig() {
}
}
cfg.Bridge.Backfill = cfg.Backfill
+ if cfg.EnvConfigPrefix != "" {
+ err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix)
+ if err != nil {
+ _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err)
+ os.Exit(10)
+ }
+ }
br.Config = &cfg
}
diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go
index 43d19380..243b91da 100644
--- a/bridgev2/matrix/provisioning.go
+++ b/bridgev2/matrix/provisioning.go
@@ -85,10 +85,9 @@ const (
provisioningUserKey provisioningContextKey = iota
provisioningUserLoginKey
provisioningLoginProcessKey
+ ProvisioningKeyRequest
)
-const ProvisioningKeyRequest = "fi.mau.provision.request"
-
func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}
@@ -97,12 +96,7 @@ func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
return prov.Router
}
-type IProvisioningAPI interface {
- GetRouter() *http.ServeMux
- GetUser(r *http.Request) *bridgev2.User
-}
-
-func (br *Connector) GetProvisioning() IProvisioningAPI {
+func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI {
return br.Provisioning
}
@@ -330,7 +324,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) {
prevState.UserID = ""
prevState.RemoteID = ""
prevState.RemoteName = ""
- prevState.RemoteProfile = nil
+ prevState.RemoteProfile = status.RemoteProfile{}
resp.Logins[i] = RespWhoamiLogin{
StateEvent: prevState.StateEvent,
StateTS: prevState.Timestamp,
@@ -409,10 +403,18 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
Override: overrideLogin,
}
prov.loginsLock.Unlock()
+ zerolog.Ctx(r.Context()).Info().
+ Any("first_step", firstStep).
+ Msg("Created login process")
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep})
}
func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) {
+ zerolog.Ctx(ctx).Info().
+ Str("step_id", step.StepID).
+ Str("user_login_id", string(step.CompleteParams.UserLoginID)).
+ Msg("Login completed successfully")
+ prov.deleteLogin(login, false)
if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID {
return
}
@@ -426,6 +428,15 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
}, bridgev2.DeleteOpts{LogoutRemote: true})
}
+func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) {
+ if cancel {
+ login.Process.Cancel()
+ }
+ prov.loginsLock.Lock()
+ delete(prov.logins, login.ID)
+ prov.loginsLock.Unlock()
+}
+
func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) {
loginID := r.PathValue("loginProcessID")
prov.loginsLock.RLock()
@@ -496,11 +507,14 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input")
RespondWithError(w, err, "Internal error submitting input")
+ prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
+ } else {
+ zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
@@ -514,11 +528,14 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait")
RespondWithError(w, err, "Internal error waiting for login")
+ prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
+ } else {
+ zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml
index 50b73c66..26068db4 100644
--- a/bridgev2/matrix/provisioning.yaml
+++ b/bridgev2/matrix/provisioning.yaml
@@ -728,15 +728,53 @@ components:
description: A more detailed description of the field shown to the user.
examples:
- Include the country code with a +
+ default_value:
+ type: string
+ description: A default value that the client can pre-fill the field with.
pattern:
type: string
format: regex
description: A regular expression that the field value must match.
- select:
+ options:
type: array
description: For fields of type select, the valid options.
items:
type: string
+ attachments:
+ type: array
+ description: A list of media attachments to show the user alongside the form fields.
+ items:
+ type: object
+ description: A media attachment to show the user.
+ required: [ type, filename, content ]
+ properties:
+ type:
+ type: string
+ description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported.
+ enum: [ m.image, m.audio ]
+ filename:
+ type: string
+ description: The filename for the media attachment.
+ content:
+ type: string
+ description: The raw file content for the attachment encoded in base64.
+ info:
+ type: object
+ description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted.
+ properties:
+ mimetype:
+ type: string
+ description: The MIME type for the media content.
+ examples: [ image/png, audio/mpeg ]
+ w:
+ type: number
+ description: The width of the media in pixels. Only applicable for images and videos.
+ h:
+ type: number
+ description: The height of the media in pixels. Only applicable for images and videos.
+ size:
+ type: number
+ description: The size of the media content in number of bytes. Strongly recommended to include.
- description: Cookie login step
required: [ type, cookies ]
properties:
diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go
index 07615daf..be26db49 100644
--- a/bridgev2/matrixinterface.go
+++ b/bridgev2/matrixinterface.go
@@ -14,6 +14,8 @@ import (
"os"
"time"
+ "go.mau.fi/util/exhttp"
+
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -26,6 +28,7 @@ type MatrixCapabilities struct {
AutoJoinInvites bool
BatchSending bool
ArbitraryMemberChange bool
+ ExtraProfileMeta bool
}
type MatrixConnector interface {
@@ -59,36 +62,54 @@ type MatrixConnector interface {
}
type MatrixConnectorWithArbitraryRoomState interface {
+ MatrixConnector
GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error)
}
type MatrixConnectorWithServer interface {
+ MatrixConnector
GetPublicAddress() string
GetRouter() *http.ServeMux
}
+type IProvisioningAPI interface {
+ GetRouter() *http.ServeMux
+ GetUser(r *http.Request) *User
+}
+
+type MatrixConnectorWithProvisioning interface {
+ MatrixConnector
+ GetProvisioning() IProvisioningAPI
+}
+
type MatrixConnectorWithPublicMedia interface {
+ MatrixConnector
GetPublicMediaAddress(contentURI id.ContentURIString) string
GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error)
}
type MatrixConnectorWithNameDisambiguation interface {
+ MatrixConnector
IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error)
}
type MatrixConnectorWithBridgeIdentifier interface {
+ MatrixConnector
GetUniqueBridgeID() string
}
type MatrixConnectorWithURLPreviews interface {
+ MatrixConnector
GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error)
}
type MatrixConnectorWithPostRoomBridgeHandling interface {
+ MatrixConnector
HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error
}
type MatrixConnectorWithAnalytics interface {
+ MatrixConnector
TrackAnalytics(userID id.UserID, event string, properties map[string]any)
}
@@ -103,9 +124,15 @@ type DirectNotificationData struct {
}
type MatrixConnectorWithNotifications interface {
+ MatrixConnector
DisplayNotification(ctx context.Context, data *DirectNotificationData)
}
+type MatrixConnectorWithHTTPSettings interface {
+ MatrixConnector
+ GetHTTPClientSettings() exhttp.ClientSettings
+}
+
type MatrixSendExtra struct {
Timestamp time.Time
MessageMeta *database.Message
@@ -183,9 +210,16 @@ type MatrixAPI interface {
}
type StreamOrderReadingMatrixAPI interface {
+ MatrixAPI
MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error
}
type MarkAsDMMatrixAPI interface {
+ MatrixAPI
MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error
}
+
+type EphemeralSendingMatrixAPI interface {
+ MatrixAPI
+ BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error)
+}
diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go
index b8a5aec6..75c00cb0 100644
--- a/bridgev2/matrixinvite.go
+++ b/bridgev2/matrixinvite.go
@@ -88,6 +88,36 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI,
rejectInvite(ctx, evt, intent, "")
}
+func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) {
+ if portal.MXID == "" {
+ return
+ }
+ log := zerolog.Ctx(ctx)
+ existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID)
+ if err != nil {
+ log.Err(err).
+ Stringer("old_portal_mxid", portal.MXID).
+ Msg("Failed to check existing portal members, deleting room")
+ } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok {
+ log.Debug().
+ Stringer("old_portal_mxid", portal.MXID).
+ Msg("Inviter has no member event in old portal, deleting room")
+ } else if targetUserMember.Membership.IsInviteOrJoin() {
+ return
+ } else {
+ log.Debug().
+ Stringer("old_portal_mxid", portal.MXID).
+ Str("membership", string(targetUserMember.Membership)).
+ Msg("Inviter is not in old portal, deleting room")
+ }
+
+ if err = portal.RemoveMXID(ctx); err != nil {
+ log.Err(err).Msg("Failed to delete old portal mxid")
+ } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
+ log.Err(err).Msg("Failed to clean up old portal room")
+ }
+}
+
func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult {
ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey()))
validator, ok := br.Network.(IdentifierValidatingNetwork)
@@ -165,34 +195,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
return EventHandlingResultFailed
}
}
- if portal.MXID != "" {
- doCleanup := true
- existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID)
- if err != nil {
- log.Err(err).
- Stringer("old_portal_mxid", portal.MXID).
- Msg("Failed to check existing portal members, deleting room")
- } else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok {
- log.Debug().
- Stringer("old_portal_mxid", portal.MXID).
- Msg("Inviter has no member event in old portal, deleting room")
- } else if targetUserMember.Membership.IsInviteOrJoin() {
- doCleanup = false
- } else {
- log.Debug().
- Stringer("old_portal_mxid", portal.MXID).
- Str("membership", string(targetUserMember.Membership)).
- Msg("Inviter is not in old portal, deleting room")
- }
-
- if doCleanup {
- if err = portal.RemoveMXID(ctx); err != nil {
- log.Err(err).Msg("Failed to delete old portal mxid")
- } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
- log.Err(err).Msg("Failed to clean up old portal room")
- }
- }
- }
+ portal.CleanupOrphanedDM(ctx, sender.MXID)
err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID())
if err != nil {
log.Err(err).Msg("Failed to ensure bot is invited to room")
diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go
index 7118649d..df0c9e4d 100644
--- a/bridgev2/messagestatus.go
+++ b/bridgev2/messagestatus.go
@@ -20,6 +20,7 @@ import (
type MessageStatusEventInfo struct {
RoomID id.RoomID
+ TransactionID string
SourceEventID id.EventID
NewEventID id.EventID
EventType event.Type
@@ -41,6 +42,7 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo {
return &MessageStatusEventInfo{
RoomID: evt.RoomID,
+ TransactionID: evt.Unsigned.TransactionID,
SourceEventID: evt.ID,
EventType: evt.Type,
MessageType: evt.Content.AsMessage().MsgType,
@@ -182,9 +184,10 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe
Type: event.RelReference,
EventID: evt.SourceEventID,
},
- Status: ms.Status,
- Reason: ms.ErrorReason,
- Message: ms.Message,
+ TargetTxnID: evt.TransactionID,
+ Status: ms.Status,
+ Reason: ms.ErrorReason,
+ Message: ms.Message,
}
if ms.InternalError != nil {
content.InternalError = ms.InternalError.Error()
diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go
index 9bbcf897..efc5f100 100644
--- a/bridgev2/networkinterface.go
+++ b/bridgev2/networkinterface.go
@@ -261,6 +261,7 @@ type NetworkConnector interface {
}
type StoppableNetwork interface {
+ NetworkConnector
// Stop is called when the bridge is stopping, after all network clients have been disconnected.
Stop()
}
@@ -295,11 +296,6 @@ type PortalBridgeInfoFillingNetwork interface {
FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent)
}
-type PersonalFilteringCustomizingNetworkAPI interface {
- NetworkAPI
- CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom)
-}
-
// ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields
// before the bridge is started.
//
@@ -322,6 +318,16 @@ type MaxFileSizeingNetwork interface {
SetMaxFileSize(maxSize int64)
}
+type NetworkResettingNetwork interface {
+ NetworkConnector
+ // ResetHTTPTransport should recreate the HTTP client used by the bridge.
+ // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable.
+ ResetHTTPTransport()
+ // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections.
+ // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here.
+ ResetNetworkConnections()
+}
+
type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
type MatrixMessageResponse struct {
@@ -712,6 +718,19 @@ type DeleteChatHandlingNetworkAPI interface {
HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error
}
+// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors
+// can implement to accept message requests from the remote network.
+type MessageRequestAcceptingNetworkAPI interface {
+ NetworkAPI
+ // HandleMatrixAcceptMessageRequest is called when the user accepts a message request.
+ HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error
+}
+
+type BeeperAIStreamHandlingNetworkAPI interface {
+ NetworkAPI
+ HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error
+}
+
type ResolveIdentifierResponse struct {
// Ghost is the ghost of the user that the identifier resolves to.
// This field should be set whenever possible. However, it is not required,
@@ -784,6 +803,16 @@ type UserSearchingNetworkAPI interface {
SearchUsers(ctx context.Context, query string) ([]*ResolveIdentifierResponse, error)
}
+type GroupCreatingNetworkAPI interface {
+ IdentifierResolvingNetworkAPI
+ CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error)
+}
+
+type PersonalFilteringCustomizingNetworkAPI interface {
+ NetworkAPI
+ CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom)
+}
+
type ProvisioningCapabilities struct {
ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"`
GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"`
@@ -855,11 +884,6 @@ type GroupCreateParams struct {
RoomID id.RoomID `json:"room_id,omitempty"`
}
-type GroupCreatingNetworkAPI interface {
- IdentifierResolvingNetworkAPI
- CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error)
-}
-
type MembershipChangeType struct {
From event.Membership
To event.Membership
@@ -897,16 +921,15 @@ type MatrixMembershipChange struct {
MatrixRoomMeta[*event.MemberEventContent]
Target GhostOrUserLogin
Type MembershipChangeType
+}
- // Deprecated: Use Target instead
- TargetGhost *Ghost
- // Deprecated: Use Target instead
- TargetUserLogin *UserLogin
+type MatrixMembershipResult struct {
+ RedirectTo networkid.UserID
}
type MembershipHandlingNetworkAPI interface {
NetworkAPI
- HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error)
+ HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error)
}
type SinglePowerLevelChange struct {
@@ -1382,7 +1405,8 @@ type MatrixMessageRemove struct {
type MatrixRoomMeta[ContentType any] struct {
MatrixEventBase[ContentType]
- PrevContent ContentType
+ PrevContent ContentType
+ IsStateRequest bool
}
type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent]
@@ -1419,6 +1443,8 @@ type MatrixViewingChat struct {
}
type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent]
+type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent]
+type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent]
type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent]
type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent]
type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent]
diff --git a/bridgev2/portal.go b/bridgev2/portal.go
index b664c8f6..16aa703b 100644
--- a/bridgev2/portal.go
+++ b/bridgev2/portal.go
@@ -86,14 +86,15 @@ type Portal struct {
lastCapUpdate time.Time
- roomCreateLock sync.Mutex
- RoomCreated *exsync.Event
+ roomCreateLock sync.Mutex
+ cancelRoomCreate atomic.Pointer[context.CancelFunc]
+ RoomCreated *exsync.Event
functionalMembersLock sync.Mutex
functionalMembersCache *event.ElementFunctionalMembersContent
events chan portalEvent
- deleted bool
+ deleted *exsync.Event
eventsLock sync.Mutex
eventIdx int
@@ -127,6 +128,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage),
RoomCreated: exsync.NewEvent(),
+ deleted: exsync.NewEvent(),
}
if portal.MXID != "" {
portal.RoomCreated.Set()
@@ -167,7 +169,9 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que
}
func (portal *Portal) updateLogger() {
- logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID))
+ logWith := portal.Bridge.Log.With().
+ Str("portal_id", string(portal.ID)).
+ Str("portal_receiver", string(portal.Receiver))
if portal.MXID != "" {
logWith = logWith.Stringer("portal_mxid", portal.MXID)
}
@@ -335,6 +339,9 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port
}
func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult {
+ if portal.deleted.IsSet() {
+ return EventHandlingResultIgnored
+ }
if PortalEventBuffer == 0 {
portal.eventsLock.Lock()
defer portal.eventsLock.Unlock()
@@ -347,6 +354,8 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHand
select {
case portal.events <- evt:
return EventHandlingResultQueued
+ case <-portal.deleted.GetChan():
+ return EventHandlingResultIgnored
default:
zerolog.Ctx(ctx).Error().
Str("portal_id", string(portal.ID)).
@@ -371,17 +380,21 @@ func (portal *Portal) eventLoop() {
go portal.pendingMessageTimeoutLoop(ctx, cfg)
defer cancel()
}
- i := 0
- for rawEvt := range portal.events {
- if portal.deleted {
+ deleteCh := portal.deleted.GetChan()
+ for i := 0; ; i++ {
+ select {
+ case rawEvt := <-portal.events:
+ if rawEvt == nil {
+ return
+ }
+ if portal.Bridge.Config.AsyncEvents {
+ go portal.handleSingleEventWithDelayLogging(i, rawEvt)
+ } else {
+ portal.handleSingleEventWithDelayLogging(i, rawEvt)
+ }
+ case <-deleteCh:
return
}
- i++
- if portal.Bridge.Config.AsyncEvents {
- go portal.handleSingleEventWithDelayLogging(i, rawEvt)
- } else {
- portal.handleSingleEventWithDelayLogging(i, rawEvt)
- }
}
}
@@ -473,6 +486,11 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context {
logWith = logWith.Int64("remote_stream_order", remoteStreamOrder)
}
}
+ if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok {
+ if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() {
+ logWith = logWith.Time("remote_timestamp", remoteTimestamp)
+ }
+ }
case *portalCreateEvent:
return evt.ctx
}
@@ -512,7 +530,14 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
}()
switch evt := rawEvt.(type) {
case *portalMatrixEvent:
- res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt)
+ isStateRequest := evt.evt.Type == event.BeeperSendState
+ if isStateRequest {
+ if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil {
+ portal.sendErrorStatus(ctx, evt.evt, err)
+ return
+ }
+ }
+ res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest)
if res.SendMSS {
if res.Error != nil {
portal.sendErrorStatus(ctx, evt.evt, res.Error)
@@ -520,9 +545,21 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
portal.sendSuccessStatus(ctx, evt.evt, 0, "")
}
}
- if res.Error != nil && evt.evt.StateKey != nil {
+ if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil {
portal.revertRoomMeta(ctx, evt.evt)
}
+ if isStateRequest && res.Success && !res.SkipStateEcho {
+ portal.sendRoomMeta(
+ ctx,
+ evt.sender.DoublePuppet(ctx),
+ time.UnixMilli(evt.evt.Timestamp),
+ evt.evt.Type,
+ evt.evt.GetStateKey(),
+ evt.evt.Content.Parsed,
+ false,
+ evt.evt.Content.Raw,
+ )
+ }
case *portalRemoteEvent:
res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt)
case *portalCreateEvent:
@@ -534,18 +571,44 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal
}
}
+func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
+ content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent)
+ if !ok {
+ return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)
+ }
+ evt.Content = content.Content
+ evt.StateKey = &content.StateKey
+ evt.Type = event.Type{Type: content.Type, Class: event.StateEventType}
+ _ = evt.Content.ParseRaw(evt.Type)
+ mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ return fmt.Errorf("matrix connector doesn't support fetching state")
+ }
+ prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey())
+ if err != nil && !errors.Is(err, mautrix.MNotFound) {
+ return fmt.Errorf("failed to get prev event: %w", err)
+ } else if prevEvt != nil {
+ evt.Unsigned.PrevContent = &prevEvt.Content
+ evt.Unsigned.PrevSender = prevEvt.Sender
+ }
+ return nil
+}
+
func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) {
if portal.Receiver != "" {
login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver)
if err != nil {
return nil, nil, err
}
- if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() {
+ if login == nil {
+ return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn)
+ } else if !login.Client.IsLoggedIn() {
+ return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn)
+ } else if login.UserMXID != user.MXID {
if allowRelay && portal.Relay != nil {
return nil, nil, nil
}
- // TODO different error for this case?
- return nil, nil, ErrNotLoggedIn
+ return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID)
}
up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey)
return login, up, err
@@ -628,7 +691,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID,
var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"}
-func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
+func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
log := zerolog.Ctx(ctx)
if evt.Mautrix.EventSource&event.SourceEphemeral != 0 {
switch evt.Type {
@@ -636,6 +699,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
return portal.handleMatrixReceipts(ctx, evt)
case event.EphemeralEventTyping:
return portal.handleMatrixTyping(ctx, evt)
+ case event.BeeperEphemeralEventAIStream:
+ return portal.handleMatrixAIStream(ctx, sender, evt)
default:
return EventHandlingResultIgnored
}
@@ -660,6 +725,9 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
}
var origSender *OrigSender
if login == nil {
+ if isStateRequest {
+ return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest)
+ }
login = portal.Relay
origSender = &OrigSender{
User: sender,
@@ -730,13 +798,13 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
case event.EventRedaction:
return portal.handleMatrixRedaction(ctx, login, origSender, evt)
case event.StateRoomName:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName)
case event.StateTopic:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic)
case event.StateRoomAvatar:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar)
case event.StateBeeperDisappearingTimer:
- return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer)
+ return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer)
case event.StateEncryption:
// TODO?
return EventHandlingResultIgnored
@@ -747,11 +815,13 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *
case event.AccountDataBeeperMute:
return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute)
case event.StateMember:
- return portal.handleMatrixMembership(ctx, login, origSender, evt)
+ return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest)
case event.StatePowerLevels:
- return portal.handleMatrixPowerLevels(ctx, login, origSender, evt)
+ return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest)
case event.BeeperDeleteChat:
return portal.handleMatrixDeleteChat(ctx, login, origSender, evt)
+ case event.BeeperAcceptMessageRequest:
+ return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt)
default:
return EventHandlingResultIgnored
}
@@ -875,6 +945,50 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event)
return EventHandlingResultSuccess
}
+func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
+ log := zerolog.Ctx(ctx)
+ if sender == nil {
+ log.Error().Msg("Missing sender for Matrix AI stream event")
+ return EventHandlingResultIgnored
+ }
+ login, _, err := portal.FindPreferredLogin(ctx, sender, true)
+ if err != nil {
+ log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ var origSender *OrigSender
+ if login == nil {
+ if portal.Relay == nil {
+ return EventHandlingResultIgnored
+ }
+ login = portal.Relay
+ origSender = &OrigSender{
+ User: sender,
+ UserID: sender.MXID,
+ }
+ }
+ content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent)
+ if !ok {
+ log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
+ return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
+ }
+ api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI)
+ if !ok {
+ return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported)
+ }
+ err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{
+ Event: evt,
+ Content: content,
+ Portal: portal,
+ OrigSender: origSender,
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to handle Matrix AI stream event")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ return EventHandlingResultSuccess.WithMSS()
+}
+
func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) {
for _, userID := range userIDs {
login, ok := portal.currentlyTypingLogins[userID]
@@ -1162,6 +1276,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin
}
}
+ err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps)
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to auto-accept message request on message")
+ // TODO stop processing?
+ }
+
var resp *MatrixMessageResponse
if msgContent != nil {
resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt)
@@ -1418,7 +1538,7 @@ func (portal *Portal) handleMatrixEdit(
return EventHandlingResultSuccess
}
-func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult {
+func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) {
log := zerolog.Ctx(ctx)
reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI)
if !ok {
@@ -1441,6 +1561,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
log.Warn().Msg("Reaction target message not found in database")
return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound))
}
+ caps := sender.Client.GetCapabilities(ctx, portal)
+ err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps)
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction")
+ // TODO stop processing?
+ }
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str("reaction_target_remote_id", string(reactionTarget.ID))
})
@@ -1463,6 +1589,31 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if portal.Bridge.Config.OutgoingMessageReID {
deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID)
}
+ defer func() {
+ // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction
+ if handleRes.Success {
+ portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
+ }
+ }()
+ removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) {
+ if !handleRes.Success {
+ return
+ }
+ _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
+ Parsed: &event.RedactionEventContent{
+ Redacts: oldReact.MXID,
+ },
+ }, nil)
+ if err != nil {
+ log.Err(err).Msg("Failed to remove old reaction")
+ }
+ if deleteDB {
+ err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact)
+ if err != nil {
+ log.Err(err).Msg("Failed to delete old reaction from database")
+ }
+ }
+ }
existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID)
if err != nil {
log.Err(err).Msg("Failed to check if reaction is a duplicate")
@@ -1474,14 +1625,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
return EventHandlingResultIgnored.WithEventID(deterministicID)
}
react.ReactionToOverride = existing
- _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
- Parsed: &event.RedactionEventContent{
- Redacts: existing.MXID,
- },
- }, nil)
- if err != nil {
- log.Err(err).Msg("Failed to remove old reaction")
- }
+ defer removeOutdatedReaction(existing, false)
}
react.PreHandleResp = &preResp
if preResp.MaxReactions > 0 {
@@ -1496,18 +1640,14 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
// Keep n-1 previous reactions and remove the rest
react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1]
for _, oldReaction := range allReactions[preResp.MaxReactions-1:] {
- _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{
- Parsed: &event.RedactionEventContent{
- Redacts: oldReaction.MXID,
- },
- }, nil)
- if err != nil {
- log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded")
- }
- err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction)
- if err != nil {
- log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded")
+ if existing != nil && oldReaction.EmojiID == existing.EmojiID {
+ // Don't double-delete on networks that only allow one emoji
+ continue
}
+ // Intentionally defer in a loop, there won't be that many items,
+ // and we want all of them to be done after this function completes successfully
+ //goland:noinspection GoDeferInLoop
+ defer removeOutdatedReaction(oldReaction, true)
}
}
}
@@ -1552,7 +1692,6 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi
if err != nil {
log.Err(err).Msg("Failed to save reaction to database")
}
- portal.sendSuccessStatus(ctx, evt, 0, deterministicID)
return EventHandlingResultSuccess.WithEventID(deterministicID)
}
@@ -1562,6 +1701,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
+ isStateRequest bool,
fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error),
) EventHandlingResult {
if evt.StateKey == nil || *evt.StateKey != "" {
@@ -1625,7 +1765,8 @@ func handleMatrixRoomMeta[APIType any, ContentType any](
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- PrevContent: prevContent,
+ IsStateRequest: isStateRequest,
+ PrevContent: prevContent,
})
if err != nil {
log.Err(err).Msg("Failed to handle Matrix room metadata")
@@ -1695,6 +1836,77 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos
}
}
+func (portal *Portal) handleMatrixAcceptMessageRequest(
+ ctx context.Context,
+ sender *UserLogin,
+ origSender *OrigSender,
+ evt *event.Event,
+) EventHandlingResult {
+ if origSender != nil {
+ return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser)
+ }
+ log := zerolog.Ctx(ctx)
+ content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent)
+ if !ok {
+ log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type")
+ return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed))
+ }
+ api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI)
+ if !ok {
+ return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported)
+ }
+ err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{
+ Event: evt,
+ Content: content,
+ Portal: portal,
+ })
+ if err != nil {
+ log.Err(err).Msg("Failed to handle Matrix accept message request")
+ return EventHandlingResultFailed.WithMSSError(err)
+ }
+ if portal.MessageRequest {
+ portal.MessageRequest = false
+ portal.UpdateBridgeInfo(ctx)
+ err = portal.Save(ctx)
+ if err != nil {
+ log.Err(err).Msg("Failed to save portal after accepting message request")
+ }
+ }
+ return EventHandlingResultSuccess.WithMSS()
+}
+
+func (portal *Portal) autoAcceptMessageRequest(
+ ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures,
+) error {
+ if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported {
+ return nil
+ }
+ mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI)
+ if !ok {
+ return nil
+ }
+ err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{
+ Event: evt,
+ Content: &event.BeeperAcceptMessageRequestEventContent{
+ IsImplicit: true,
+ },
+ Portal: portal,
+ OrigSender: origSender,
+ })
+ if err != nil {
+ return err
+ }
+ if portal.MessageRequest {
+ portal.MessageRequest = false
+ portal.UpdateBridgeInfo(ctx)
+ err = portal.Save(ctx)
+ if err != nil {
+ zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request")
+ }
+ }
+ return nil
+}
+
func (portal *Portal) handleMatrixDeleteChat(
ctx context.Context,
sender *UserLogin,
@@ -1752,6 +1964,7 @@ func (portal *Portal) handleMatrixMembership(
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
+ isStateRequest bool,
) EventHandlingResult {
if evt.StateKey == nil {
return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
@@ -1791,7 +2004,6 @@ func (portal *Portal) handleMatrixMembership(
return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent)
}
targetGhost, _ := target.(*Ghost)
- targetUserLogin, _ := target.(*UserLogin)
membershipChange := &MatrixMembershipChange{
MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{
MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{
@@ -1802,19 +2014,60 @@ func (portal *Portal) handleMatrixMembership(
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- PrevContent: prevContent,
+ IsStateRequest: isStateRequest,
+ PrevContent: prevContent,
},
- Target: target,
- TargetGhost: targetGhost,
- TargetUserLogin: targetUserLogin,
- Type: membershipChangeType,
+ Target: target,
+ Type: membershipChangeType,
}
- _, err = api.HandleMatrixMembership(ctx, membershipChange)
+ res, err := api.HandleMatrixMembership(ctx, membershipChange)
if err != nil {
log.Err(err).Msg("Failed to handle Matrix membership change")
return EventHandlingResultFailed.WithMSSError(err)
}
- return EventHandlingResultSuccess.WithMSS()
+ didRedirectInvite := membershipChangeType == Invite &&
+ targetGhost != nil &&
+ res != nil &&
+ res.RedirectTo != "" &&
+ res.RedirectTo != targetGhost.ID
+ if didRedirectInvite {
+ log.Debug().
+ Str("orig_id", string(targetGhost.ID)).
+ Str("redirect_id", string(res.RedirectTo)).
+ Msg("Invite was redirected to different ghost")
+ var redirectGhost *Ghost
+ redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo)
+ if err != nil {
+ log.Err(err).Msg("Failed to get redirect target ghost")
+ return EventHandlingResultFailed.WithError(err)
+ }
+ if !isStateRequest {
+ portal.sendRoomMeta(
+ ctx,
+ sender.User.DoublePuppet(ctx),
+ time.UnixMilli(evt.Timestamp),
+ event.StateMember,
+ evt.GetStateKey(),
+ &event.MemberEventContent{
+ Membership: event.MembershipLeave,
+ Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo),
+ },
+ true,
+ nil,
+ )
+ }
+ portal.sendRoomMeta(
+ ctx,
+ sender.User.DoublePuppet(ctx),
+ time.UnixMilli(evt.Timestamp),
+ event.StateMember,
+ redirectGhost.Intent.GetMXID().String(),
+ content,
+ false,
+ nil,
+ )
+ }
+ return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite)
}
func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange {
@@ -1839,6 +2092,7 @@ func (portal *Portal) handleMatrixPowerLevels(
sender *UserLogin,
origSender *OrigSender,
evt *event.Event,
+ isStateRequest bool,
) EventHandlingResult {
if evt.StateKey == nil || *evt.StateKey != "" {
return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey)
@@ -1880,7 +2134,8 @@ func (portal *Portal) handleMatrixPowerLevels(
InputTransactionID: portal.parseInputTransactionID(origSender, evt),
},
- PrevContent: prevContent,
+ IsStateRequest: isStateRequest,
+ PrevContent: prevContent,
},
Users: make(map[id.UserID]*UserPowerLevelChange),
Events: make(map[string]*SinglePowerLevelChange),
@@ -2334,7 +2589,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin,
}
func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) {
- if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID {
+ if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" {
return
}
ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
@@ -2508,7 +2763,7 @@ func (portal *Portal) getRelationMeta(
log.Err(err).Msg("Failed to get last thread message from database")
}
if prevThreadEvent == nil {
- prevThreadEvent = threadRoot
+ prevThreadEvent = ptr.Clone(threadRoot)
}
}
return
@@ -3446,7 +3701,7 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo
}
func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult {
- if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID {
+ if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") {
return EventHandlingResultIgnored
}
intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt)
@@ -3851,9 +4106,9 @@ type ChatInfo struct {
Disappear *database.DisappearingSetting
ParentID *networkid.PortalID
- UserLocal *UserLocalPortalInfo
-
- CanBackfill bool
+ UserLocal *UserLocalPortalInfo
+ MessageRequest *bool
+ CanBackfill bool
ExcludeChangesFromTimeline bool
@@ -3973,10 +4228,11 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) {
Creator: portal.Bridge.Bot.GetMXID(),
Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(),
Channel: event.BridgeInfoSection{
- ID: string(portal.ID),
- DisplayName: portal.Name,
- AvatarURL: portal.AvatarMXC,
- Receiver: string(portal.Receiver),
+ ID: string(portal.ID),
+ DisplayName: portal.Name,
+ AvatarURL: portal.AvatarMXC,
+ Receiver: string(portal.Receiver),
+ MessageRequest: portal.MessageRequest,
// TODO external URL?
},
BeeperRoomTypeV2: string(portal.RoomType),
@@ -4253,7 +4509,11 @@ func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool {
}
func (portal *Portal) roomIsPublic(ctx context.Context) bool {
- evt, err := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState).GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "")
+ mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState)
+ if !ok {
+ return false
+ }
+ evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "")
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public")
return false
@@ -4714,6 +4974,10 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us
portal.RoomType = *info.Type
}
}
+ if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest {
+ changed = true
+ portal.MessageRequest = *info.MessageRequest
+ }
if info.Members != nil && portal.MXID != "" && source != nil {
err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{})
if err != nil {
@@ -4755,6 +5019,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
}
return nil
}
+ if portal.deleted.IsSet() {
+ return ErrPortalIsDeleted
+ }
waiter := make(chan struct{})
closed := false
evt := &portalCreateEvent{
@@ -4772,7 +5039,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
if PortalEventBuffer == 0 {
go portal.queueEvent(ctx, evt)
} else {
- portal.events <- evt
+ select {
+ case portal.events <- evt:
+ case <-portal.deleted.GetChan():
+ return ErrPortalIsDeleted
+ }
}
select {
case <-ctx.Done():
@@ -4783,7 +5054,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i
}
func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error {
+ cancellableCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ portal.cancelRoomCreate.CompareAndSwap(nil, &cancel)
portal.roomCreateLock.Lock()
+ portal.cancelRoomCreate.Store(&cancel)
defer portal.roomCreateLock.Unlock()
if portal.MXID != "" {
if source != nil {
@@ -4794,6 +5069,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
log := zerolog.Ctx(ctx).With().
Str("action", "create matrix room").
Logger()
+ cancellableCtx = log.WithContext(cancellableCtx)
ctx = log.WithContext(ctx)
log.Info().Msg("Creating Matrix room")
@@ -4802,16 +5078,16 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
if info != nil {
log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info")
}
- info, err = source.Client.GetChatInfo(ctx, portal)
+ info, err = source.Client.GetChatInfo(cancellableCtx, portal)
if err != nil {
log.Err(err).Msg("Failed to update portal info for creation")
return err
}
}
- portal.UpdateInfo(ctx, info, source, nil, time.Time{})
- if ctx.Err() != nil {
- return ctx.Err()
+ portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{})
+ if cancellableCtx.Err() != nil {
+ return cancellableCtx.Err()
}
powerLevels := &event.PowerLevelsEventContent{
@@ -4824,7 +5100,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
portal.Bridge.Bot.GetMXID(): 9001,
},
}
- initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels)
+ initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels)
if err != nil {
log.Err(err).Msg("Failed to process participant list for portal creation")
return err
@@ -4839,7 +5115,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
IsDirect: portal.RoomType == database.RoomTypeDM,
PowerLevelOverride: powerLevels,
BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey),
- RoomVersion: id.RoomV11,
}
autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites
if autoJoinInvites {
@@ -4852,7 +5127,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
req.CreationContent["type"] = event.RoomTypeSpace
}
bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo()
- roomFeatures := source.Client.GetCapabilities(ctx, portal)
+ roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal)
portal.CapState = database.CapabilityState{
Source: source.ID,
ID: roomFeatures.GetID(),
@@ -4934,6 +5209,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
Content: event.Content{Parsed: info.JoinRule},
})
}
+ if cancellableCtx.Err() != nil {
+ return cancellableCtx.Err()
+ }
roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req)
if err != nil {
log.Err(err).Msg("Failed to create Matrix room")
@@ -4992,7 +5270,10 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo
}
}
portal.addToUserSpaces(ctx)
- if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background {
+ if info.CanBackfill &&
+ portal.Bridge.Config.Backfill.Enabled &&
+ portal.RoomType != database.RoomTypeSpace &&
+ !portal.Bridge.Background {
portal.doForwardBackfill(ctx, source, nil, backfillBundle)
}
return nil
@@ -5032,8 +5313,11 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) {
}
func (portal *Portal) Delete(ctx context.Context) error {
+ if portal.deleted.IsSet() {
+ return nil
+ }
portal.removeInPortalCache(ctx)
- err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
+ err := portal.safeDBDelete(ctx)
if err != nil {
return err
}
@@ -5043,6 +5327,15 @@ func (portal *Portal) Delete(ctx context.Context) error {
return nil
}
+func (portal *Portal) safeDBDelete(ctx context.Context) error {
+ err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey)
+ if err != nil {
+ return fmt.Errorf("failed to delete messages in portal: %w", err)
+ }
+ // TODO delete child portals?
+ return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
+}
+
func (portal *Portal) RemoveMXID(ctx context.Context) error {
if portal.MXID == "" {
return nil
@@ -5081,8 +5374,10 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) {
}
func (portal *Portal) unlockedDelete(ctx context.Context) error {
- // TODO delete child portals?
- err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey)
+ if portal.deleted.IsSet() {
+ return nil
+ }
+ err := portal.safeDBDelete(ctx)
if err != nil {
return err
}
@@ -5091,15 +5386,18 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error {
}
func (portal *Portal) unlockedDeleteCache() {
+ if portal.deleted.IsSet() {
+ return
+ }
delete(portal.Bridge.portalsByKey, portal.PortalKey)
if portal.MXID != "" {
delete(portal.Bridge.portalsByMXID, portal.MXID)
}
+ portal.deleted.Set()
if portal.events != nil {
// TODO there's a small risk of this racing with a queueEvent call
close(portal.events)
}
- portal.deleted = true
}
func (portal *Portal) Save(ctx context.Context) error {
@@ -5107,6 +5405,9 @@ func (portal *Portal) Save(ctx context.Context) error {
}
func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error {
+ if portal.Receiver != "" && relay.ID != portal.Receiver {
+ return fmt.Errorf("can't set non-receiver login as relay")
+ }
portal.Relay = relay
if relay == nil {
portal.RelayLoginID = ""
diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go
index 88503380..879f07ae 100644
--- a/bridgev2/portalbackfill.go
+++ b/bridgev2/portalbackfill.go
@@ -194,6 +194,9 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t
if err != nil {
log.Err(err).Msg("Failed to get last thread message")
return
+ } else if anchorMessage == nil {
+ log.Warn().Msg("No messages found in thread?")
+ return
}
resp := portal.fetchThreadBackfill(ctx, source, anchorMessage)
if resp != nil {
@@ -407,6 +410,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
if reaction.Timestamp.IsZero() {
reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond)
}
+ //lint:ignore SA4006 it's a todo
targetPart, ok := partMap[*reaction.TargetPart]
if !ok {
// TODO warning log and/or skip reaction?
diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go
index 749ee389..4c7e2447 100644
--- a/bridgev2/portalinternal.go
+++ b/bridgev2/portalinternal.go
@@ -49,6 +49,10 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any
(*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback)
}
+func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
+ return (*Portal)(portal).unwrapBeeperSendState(ctx, evt)
+}
+
func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) {
(*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID)
}
@@ -61,8 +65,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i
return (*Portal)(portal).checkConfusableName(ctx, userID, name)
}
-func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt)
+func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest)
}
func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult {
@@ -125,12 +129,12 @@ func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sende
return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt)
}
-func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt)
+func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest)
}
-func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
- return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt)
+func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
+ return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest)
}
func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult {
@@ -305,6 +309,10 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha
return (*Portal)(portal).updateOtherUser(ctx, members)
}
+func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool {
+ return (*Portal)(portal).roomIsPublic(ctx)
+}
+
func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error {
return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts)
}
diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go
index a25fe820..c976d97c 100644
--- a/bridgev2/portalreid.go
+++ b/bridgev2/portalreid.go
@@ -32,21 +32,40 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
if source == target {
return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same")
}
- log := zerolog.Ctx(ctx)
- log.Debug().Msg("Re-ID'ing portal")
+ log := zerolog.Ctx(ctx).With().
+ Str("action", "re-id portal").
+ Stringer("source_portal_key", source).
+ Stringer("target_portal_key", target).
+ Logger()
+ ctx = log.WithContext(ctx)
defer func() {
log.Debug().Msg("Finished handling portal re-ID")
}()
- br.cacheLock.Lock()
- defer br.cacheLock.Unlock()
- sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true)
+ acquireCacheLock := func() {
+ if !br.cacheLock.TryLock() {
+ log.Debug().Msg("Waiting for global cache lock")
+ br.cacheLock.Lock()
+ log.Debug().Msg("Acquired global cache lock after waiting")
+ } else {
+ log.Trace().Msg("Acquired global cache lock without waiting")
+ }
+ }
+ log.Debug().Msg("Re-ID'ing portal")
+ sourcePortal, err := br.GetExistingPortalByKey(ctx, source)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err)
} else if sourcePortal == nil {
log.Debug().Msg("Source portal not found, re-ID is no-op")
return ReIDResultNoOp, nil, nil
}
- sourcePortal.roomCreateLock.Lock()
+ if !sourcePortal.roomCreateLock.TryLock() {
+ if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
+ (*cancelCreate)()
+ }
+ log.Debug().Msg("Waiting for source portal room creation lock")
+ sourcePortal.roomCreateLock.Lock()
+ log.Debug().Msg("Acquired source portal room creation lock after waiting")
+ }
defer sourcePortal.roomCreateLock.Unlock()
if sourcePortal.MXID == "" {
log.Info().Msg("Source portal doesn't have Matrix room, deleting row")
@@ -59,22 +78,37 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Stringer("source_portal_mxid", sourcePortal.MXID)
})
+
+ acquireCacheLock()
targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true)
if err != nil {
+ br.cacheLock.Unlock()
return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err)
}
if targetPortal == nil {
log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal")
err = sourcePortal.unlockedReID(ctx, target)
+ br.cacheLock.Unlock()
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err)
}
return ReIDResultSourceReIDd, sourcePortal, nil
}
- targetPortal.roomCreateLock.Lock()
+ br.cacheLock.Unlock()
+
+ if !targetPortal.roomCreateLock.TryLock() {
+ if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
+ (*cancelCreate)()
+ }
+ log.Debug().Msg("Waiting for target portal room creation lock")
+ targetPortal.roomCreateLock.Lock()
+ log.Debug().Msg("Acquired target portal room creation lock after waiting")
+ }
defer targetPortal.roomCreateLock.Unlock()
if targetPortal.MXID == "" {
log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal")
+ acquireCacheLock()
+ defer br.cacheLock.Unlock()
err = targetPortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err)
@@ -89,6 +123,9 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
return c.Stringer("target_portal_mxid", targetPortal.MXID)
})
log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal")
+ sourcePortal.removeInPortalCache(ctx)
+ acquireCacheLock()
+ defer br.cacheLock.Unlock()
err = sourcePortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err)
@@ -96,7 +133,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
go func() {
_, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{
Parsed: &event.TombstoneEventContent{
- Body: fmt.Sprintf("This room has been merged"),
+ Body: "This room has been merged",
ReplacementRoom: targetPortal.MXID,
},
}, time.Now())
diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go
index fbe0a513..72bacaff 100644
--- a/bridgev2/provisionutil/creategroup.go
+++ b/bridgev2/provisionutil/creategroup.go
@@ -32,6 +32,9 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev
if !ok {
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups"))
}
+ zerolog.Ctx(ctx).Debug().
+ Any("create_params", params).
+ Msg("Creating group chat on remote network")
caps := login.Bridge.Network.GetCapabilities()
typeSpec, validType := caps.Provisioning.GroupCreation[params.Type]
if !validType {
@@ -98,6 +101,9 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev
if resp.PortalKey.IsEmpty() {
return nil, ErrNoPortalKey
}
+ zerolog.Ctx(ctx).Debug().
+ Object("portal_key", resp.PortalKey).
+ Msg("Successfully created group on remote network")
if resp.Portal == nil {
resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey)
if err != nil {
diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go
index 5387347c..cfc388d0 100644
--- a/bridgev2/provisionutil/resolveidentifier.go
+++ b/bridgev2/provisionutil/resolveidentifier.go
@@ -109,6 +109,7 @@ func ResolveIdentifier(
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
}
}
+ resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID)
if createChat && resp.Chat.Portal.MXID == "" {
apiResp.JustCreated = true
err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo)
diff --git a/bridgev2/queue.go b/bridgev2/queue.go
index 308d03c5..3775c825 100644
--- a/bridgev2/queue.go
+++ b/bridgev2/queue.go
@@ -67,6 +67,7 @@ var (
ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage())
+ ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage()
)
func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult {
@@ -159,6 +160,8 @@ type EventHandlingResult struct {
Ignored bool
Queued bool
+ SkipStateEcho bool
+
// Error is an optional reason for failure. It is not required, Success may be false even without a specific error.
Error error
// Whether the Error should be sent as a MSS event.
@@ -194,6 +197,11 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult {
return ehr
}
+func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult {
+ ehr.SkipStateEcho = skip
+ return ehr
+}
+
func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult {
if err == nil {
return ehr
@@ -212,7 +220,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult {
return ul.Bridge.QueueRemoteEvent(ul, evt)
}
-func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) {
+func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult {
log := login.Log
ctx := log.WithContext(br.BackgroundCtx)
maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver)
@@ -228,14 +236,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res Event
if err != nil {
log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain).
Msg("Failed to get portal to handle remote event")
- return
+ return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err))
} else if portal == nil {
log.Warn().
Stringer("event_type", evt.GetType()).
Object("portal_key", key).
Bool("uncertain_receiver", isUncertain).
Msg("Portal not found to handle remote event")
- return
+ return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler)
}
// TODO put this in a better place, and maybe cache to avoid constant db queries
login.MarkInPortal(ctx, portal)
diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go
index 8aa91866..449a8773 100644
--- a/bridgev2/simplevent/meta.go
+++ b/bridgev2/simplevent/meta.go
@@ -101,6 +101,18 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E
return evt
}
+func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta {
+ origFunc := evt.LogContext
+ if origFunc == nil {
+ evt.LogContext = f
+ return evt
+ }
+ evt.LogContext = func(c zerolog.Context) zerolog.Context {
+ return f(origFunc(c))
+ }
+ return evt
+}
+
func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta {
evt.PortalKey = p
return evt
diff --git a/bridgev2/space.go b/bridgev2/space.go
index f6d07922..2ca2bce3 100644
--- a/bridgev2/space.go
+++ b/bridgev2/space.go
@@ -164,8 +164,7 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) {
ul.UserMXID: 50,
},
},
- RoomVersion: id.RoomV11,
- Invite: []id.UserID{ul.UserMXID},
+ Invite: []id.UserID{ul.UserMXID},
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{ul.UserMXID}
diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go
index 430d4c7c..5925dd4f 100644
--- a/bridgev2/status/bridgestate.go
+++ b/bridgev2/status/bridgestate.go
@@ -19,7 +19,6 @@ import (
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
- "go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/networkid"
@@ -112,7 +111,7 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile {
return other
}
-func (rp *RemoteProfile) IsEmpty() bool {
+func (rp *RemoteProfile) IsZero() bool {
return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil)
}
@@ -130,7 +129,7 @@ type BridgeState struct {
UserID id.UserID `json:"user_id,omitempty"`
RemoteID networkid.UserLoginID `json:"remote_id,omitempty"`
RemoteName string `json:"remote_name,omitempty"`
- RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"`
+ RemoteProfile RemoteProfile `json:"remote_profile,omitzero"`
Reason string `json:"reason,omitempty"`
Info map[string]interface{} `json:"info,omitempty"`
@@ -210,7 +209,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool {
pong.StateEvent == newPong.StateEvent &&
pong.RemoteName == newPong.RemoteName &&
pong.UserAction == newPong.UserAction &&
- ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) &&
+ pong.RemoteProfile == newPong.RemoteProfile &&
pong.Error == newPong.Error &&
maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) &&
pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now())
diff --git a/bridgev2/user.go b/bridgev2/user.go
index af9e9694..9a7896d6 100644
--- a/bridgev2/user.go
+++ b/bridgev2/user.go
@@ -229,9 +229,8 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) {
user.MXID: 50,
},
},
- RoomVersion: id.RoomV11,
- Invite: []id.UserID{user.MXID},
- IsDirect: true,
+ Invite: []id.UserID{user.MXID},
+ IsDirect: true,
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{user.MXID}
diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go
index b5fcfcd0..d56dc4cc 100644
--- a/bridgev2/userlogin.go
+++ b/bridgev2/userlogin.go
@@ -10,6 +10,7 @@ import (
"cmp"
"context"
"fmt"
+ "maps"
"slices"
"sync"
"time"
@@ -50,6 +51,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
+ // TODO if loading the user caused the provided userlogin to be loaded, cancel here?
+ // Currently this will double-load it
}
userLogin := &UserLogin{
UserLogin: dbUserLogin,
@@ -140,6 +143,12 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin {
return br.userLoginsByID[id]
}
+func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) {
+ br.cacheLock.Lock()
+ defer br.cacheLock.Unlock()
+ return slices.Collect(maps.Values(br.userLoginsByID))
+}
+
func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
@@ -503,7 +512,7 @@ func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeStat
state.UserID = ul.UserMXID
state.RemoteID = ul.ID
state.RemoteName = ul.RemoteName
- state.RemoteProfile = &ul.RemoteProfile
+ state.RemoteProfile = ul.RemoteProfile
filler, ok := ul.Client.(status.BridgeStateFiller)
if ok {
return filler.FillBridgeState(state)
diff --git a/client.go b/client.go
index d07bede5..7062d9b9 100644
--- a/client.go
+++ b/client.go
@@ -386,7 +386,14 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er
}
}
if body := req.Context().Value(LogBodyContextKey); body != nil {
- evt.Interface("req_body", body)
+ switch typedLogBody := body.(type) {
+ case json.RawMessage:
+ evt.RawJSON("req_body", typedLogBody)
+ case string:
+ evt.Str("req_body", typedLogBody)
+ default:
+ panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body))
+ }
}
if errors.Is(err, context.Canceled) {
evt.Msg("Request canceled")
@@ -450,8 +457,10 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
}
if params.SensitiveContent && !logSensitiveContent {
logBody = ""
+ } else if len(jsonStr) > 32768 {
+ logBody = fmt.Sprintf("", len(jsonStr))
} else {
- logBody = params.RequestJSON
+ logBody = json.RawMessage(jsonStr)
}
reqBody = bytes.NewReader(jsonStr)
reqLen = int64(len(jsonStr))
@@ -476,7 +485,7 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
}
} else if params.Method != http.MethodGet && params.Method != http.MethodHead {
params.RequestJSON = struct{}{}
- logBody = params.RequestJSON
+ logBody = json.RawMessage("{}")
reqBody = bytes.NewReader([]byte("{}"))
reqLen = 2
}
@@ -614,7 +623,9 @@ func (cli *Client) doRetry(
select {
case <-time.After(backoff):
case <-req.Context().Done():
- return nil, nil, req.Context().Err()
+ if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) {
+ return nil, nil, req.Context().Err()
+ }
}
if cli.UpdateRequestOnRetry != nil {
req = cli.UpdateRequestOnRetry(req, cause)
@@ -740,12 +751,15 @@ func (cli *Client) executeCompiledRequest(
cli.RequestStart(req)
startTime := time.Now()
res, err := client.Do(req)
- duration := time.Now().Sub(startTime)
+ duration := time.Since(startTime)
if res != nil && !dontReadResponse {
defer res.Body.Close()
}
if err != nil {
- if retries > 0 && !errors.Is(err, context.Canceled) {
+ // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry
+ canRetry := !errors.Is(err, context.Canceled) ||
+ errors.Is(context.Cause(req.Context()), ErrContextCancelRetry)
+ if retries > 0 && canRetry {
return cli.doRetry(
req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
)
@@ -857,7 +871,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp
}
start := time.Now()
_, err = cli.MakeFullRequest(ctx, fullReq)
- duration := time.Now().Sub(start)
+ duration := time.Since(start)
timeout := time.Duration(req.Timeout) * time.Millisecond
buffer := 10 * time.Second
if req.Since == "" {
@@ -904,7 +918,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp
return
}
-func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
+func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
var bodyBytes []byte
bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
@@ -928,7 +942,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (
// Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
//
// Registers with kind=user. For kind=guest, see RegisterGuest.
-func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
+func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
u := cli.BuildClientURL("v3", "register")
return cli.register(ctx, u, req)
}
@@ -937,7 +951,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegiste
// with kind=guest.
//
// For kind=user, see Register.
-func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
+func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
query := map[string]string{
"kind": "guest",
}
@@ -960,8 +974,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe
// panic(err)
// }
// token := res.AccessToken
-func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) {
- res, uia, err := cli.Register(ctx, req)
+func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) {
+ _, uia, err := cli.Register(ctx, req)
if err != nil && uia == nil {
return nil, err
} else if uia == nil {
@@ -970,7 +984,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe
return nil, errors.New("server does not support m.login.dummy")
}
req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session}
- res, _, err = cli.Register(ctx, req)
+ res, _, err := cli.Register(ctx, req)
if err != nil {
return nil, err
}
@@ -1144,7 +1158,9 @@ func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit
}
func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) {
- if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) {
+ supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms)
+ supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms)
+ if cli.SpecVersions != nil && !supportsUnstable && !supportsStable {
err = fmt.Errorf("server does not support fetching mutual rooms")
return
}
@@ -1154,7 +1170,10 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex
if len(extras) > 0 {
query["from"] = extras[0].From
}
- urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
+ urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query)
+ if !supportsStable && supportsUnstable {
+ urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
+ }
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
@@ -1319,6 +1338,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
+ if req.UnstableStickyDuration > 0 {
+ queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
+ }
if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted {
var isEncrypted bool
@@ -1342,6 +1364,48 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
return
}
+// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint.
+// contentJSON should be a value that can be encoded as JSON using json.Marshal.
+func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
+ var req ReqSendEvent
+ if len(extra) > 0 {
+ req = extra[0]
+ }
+
+ var txnID string
+ if len(req.TransactionID) > 0 {
+ txnID = req.TransactionID
+ } else {
+ txnID = cli.TxnID()
+ }
+
+ queryParams := map[string]string{}
+ if req.Timestamp > 0 {
+ queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
+ }
+
+ if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted {
+ var isEncrypted bool
+ isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID)
+ if err != nil {
+ err = fmt.Errorf("failed to check if room is encrypted: %w", err)
+ return
+ }
+ if isEncrypted {
+ if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil {
+ err = fmt.Errorf("failed to encrypt event: %w", err)
+ return
+ }
+ eventType = event.EventEncrypted
+ }
+ }
+
+ urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID}
+ urlPath := cli.BuildURLWithQuery(urlData, queryParams)
+ _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
+ return
+}
+
// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
@@ -1360,6 +1424,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
+ if req.UnstableStickyDuration > 0 {
+ queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
+ }
if req.Timestamp > 0 {
queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
}
@@ -1746,6 +1813,8 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any
return nil, nil
}
+type RoomStateMap = map[event.Type]map[string]*event.Event
+
// State gets all state in a room.
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate
func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) {
@@ -1828,6 +1897,9 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa
}
func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) {
+ if mxcURL.IsEmpty() {
+ return nil, fmt.Errorf("empty mxc uri provided to Download")
+ }
_, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
Method: http.MethodGet,
URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID),
@@ -1842,6 +1914,9 @@ type DownloadThumbnailExtra struct {
}
func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) {
+ if mxcURL.IsEmpty() {
+ return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail")
+ }
if len(extras) > 1 {
panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras)))
}
@@ -1914,10 +1989,15 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr
}
req.MXC = resp.ContentURI
req.UnstableUploadURL = resp.UnstableUploadURL
+ if req.AsyncContext == nil {
+ req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background())
+ }
go func() {
- _, err = cli.UploadMedia(ctx, req)
+ _, err = cli.UploadMedia(req.AsyncContext, req)
if err != nil {
- cli.Log.Error().Stringer("mxc", req.MXC).Err(err).Msg("Async upload of media failed")
+ zerolog.Ctx(req.AsyncContext).Err(err).
+ Stringer("mxc", req.MXC).
+ Msg("Async upload of media failed")
}
}()
return resp, nil
@@ -1953,6 +2033,7 @@ type ReqUploadMedia struct {
ContentType string
FileName string
+ AsyncContext context.Context
DoneCallback func()
// MXC specifies an existing MXC URI which doesn't have content yet to upload into.
@@ -1965,7 +2046,10 @@ type ReqUploadMedia struct {
}
func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) {
- cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL")
+ cli.Log.Debug().
+ Str("url", url).
+ Int64("content_length", contentLength).
+ Msg("Uploading media to external URL")
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content)
if err != nil {
return nil, err
@@ -2014,8 +2098,16 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*
Msg("Error uploading media to external URL, not retrying")
return nil, err
}
- cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err).
+ backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries)
+ cli.Log.Warn().Err(err).
+ Str("url", data.UnstableUploadURL).
+ Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Error uploading media to external URL, retrying")
+ select {
+ case <-time.After(backoff):
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
retries--
_, err = readerSeeker.Seek(0, io.SeekStart)
if err != nil {
@@ -2595,13 +2687,13 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req
return err
}
-func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error {
+func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error {
urlPath := cli.BuildClientURL("v3", "devices", deviceID)
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err
}
-func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error {
+func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error {
urlPath := cli.BuildClientURL("v3", "delete_devices")
_, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil)
return err
@@ -2612,7 +2704,7 @@ type UIACallback = func(*RespUserInteractive) interface{}
// UploadCrossSigningKeys uploads the given cross-signing keys to the server.
// Because the endpoint requires user-interactive authentication a callback must be provided that,
// given the UI auth parameters, produces the required result (or nil to end the flow).
-func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error {
+func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error {
content, err := cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"),
@@ -2703,30 +2795,51 @@ func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespW
return
}
-// UnstableGetSuspendedStatus uses MSC4323 to check if a user is suspended.
-func (cli *Client) UnstableGetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) {
- urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID)
+func (cli *Client) makeMSC4323URL(action string, target id.UserID) string {
+ if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) {
+ return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target)
+ } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) {
+ return cli.BuildClientURL("v1", "admin", action, target)
+ }
+ return ""
+}
+
+// GetSuspendedStatus uses MSC4323 to check if a user is suspended.
+func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) {
+ urlPath := cli.makeMSC4323URL("suspend", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ }
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
return
}
-// UnstableGetLockStatus uses MSC4323 to check if a user is locked.
-func (cli *Client) UnstableGetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) {
- urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID)
+// GetLockStatus uses MSC4323 to check if a user is locked.
+func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) {
+ urlPath := cli.makeMSC4323URL("lock", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ }
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
return
}
-// UnstableSetSuspendedStatus uses MSC4323 to set whether a user account is suspended.
-func (cli *Client) UnstableSetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) {
- urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID)
+// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended.
+func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) {
+ urlPath := cli.makeMSC4323URL("suspend", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ }
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res)
return
}
-// UnstableSetLockStatus uses MSC4323 to set whether a user account is locked.
-func (cli *Client) UnstableSetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) {
- urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID)
+// SetLockStatus uses MSC4323 to set whether a user account is locked.
+func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) {
+ urlPath := cli.makeMSC4323URL("lock", userID)
+ if urlPath == "" {
+ return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
+ }
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res)
return
}
diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go
new file mode 100644
index 00000000..c2846427
--- /dev/null
+++ b/client_ephemeral_test.go
@@ -0,0 +1,158 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mautrix_test
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) {
+ roomID := id.RoomID("!room:example.com")
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+ txnID := "txn-123"
+
+ var gotPath string
+ var gotQueryTS string
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ gotQueryTS = r.URL.Query().Get("ts")
+ assert.Equal(t, http.MethodPut, r.Method)
+ _, _ = w.Write([]byte(`{"event_id":"$evt"}`))
+ }))
+ defer ts.Close()
+
+ cli, err := mautrix.NewClient(ts.URL, "", "")
+ require.NoError(t, err)
+
+ _, err = cli.BeeperSendEphemeralEvent(
+ context.Background(),
+ roomID,
+ evtType,
+ map[string]any{"foo": "bar"},
+ mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234},
+ )
+ require.NoError(t, err)
+
+ assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/"))
+ assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID))
+ assert.Equal(t, "1234", gotQueryTS)
+}
+
+func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`))
+ }))
+ defer ts.Close()
+
+ cli, err := mautrix.NewClient(ts.URL, "", "")
+ require.NoError(t, err)
+
+ _, err = cli.BeeperSendEphemeralEvent(
+ context.Background(),
+ id.RoomID("!room:example.com"),
+ event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType},
+ map[string]any{"foo": "bar"},
+ )
+ require.Error(t, err)
+ assert.True(t, errors.Is(err, mautrix.MUnrecognized))
+}
+
+func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) {
+ roomID := id.RoomID("!room:example.com")
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+ txnID := "txn-encrypted"
+
+ stateStore := mautrix.NewMemoryStateStore()
+ err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{
+ Algorithm: id.AlgorithmMegolmV1,
+ })
+ require.NoError(t, err)
+
+ fakeCrypto := &fakeCryptoHelper{
+ encryptedContent: &event.EncryptedEventContent{
+ Algorithm: id.AlgorithmMegolmV1,
+ MegolmCiphertext: []byte("ciphertext"),
+ },
+ }
+
+ var gotPath string
+ var gotBody map[string]any
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ assert.Equal(t, http.MethodPut, r.Method)
+ err := json.NewDecoder(r.Body).Decode(&gotBody)
+ require.NoError(t, err)
+ _, _ = w.Write([]byte(`{"event_id":"$evt"}`))
+ }))
+ defer ts.Close()
+
+ cli, err := mautrix.NewClient(ts.URL, "", "")
+ require.NoError(t, err)
+ cli.StateStore = stateStore
+ cli.Crypto = fakeCrypto
+
+ _, err = cli.BeeperSendEphemeralEvent(
+ context.Background(),
+ roomID,
+ evtType,
+ map[string]any{"foo": "bar"},
+ mautrix.ReqSendEvent{TransactionID: txnID},
+ )
+ require.NoError(t, err)
+
+ assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID))
+ assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"])
+ assert.Equal(t, 1, fakeCrypto.encryptCalls)
+ assert.Equal(t, roomID, fakeCrypto.lastRoomID)
+ assert.Equal(t, evtType, fakeCrypto.lastEventType)
+}
+
+type fakeCryptoHelper struct {
+ encryptCalls int
+ lastRoomID id.RoomID
+ lastEventType event.Type
+ lastEncryptInput any
+ encryptedContent *event.EncryptedEventContent
+}
+
+func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) {
+ f.encryptCalls++
+ f.lastRoomID = roomID
+ f.lastEventType = eventType
+ f.lastEncryptInput = content
+ return f.encryptedContent, nil
+}
+
+func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) {
+ return nil, nil
+}
+
+func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool {
+ return false
+}
+
+func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) {
+}
+
+func (f *fakeCryptoHelper) Init(context.Context) error {
+ return nil
+}
diff --git a/commands/container.go b/commands/container.go
index bc685b7b..9b909b75 100644
--- a/commands/container.go
+++ b/commands/container.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,14 +8,20 @@ package commands
import (
"fmt"
+ "slices"
"strings"
"sync"
+
+ "go.mau.fi/util/exmaps"
+
+ "maunium.net/go/mautrix/event/cmdschema"
)
type CommandContainer[MetaType any] struct {
commands map[string]*Handler[MetaType]
aliases map[string]string
lock sync.RWMutex
+ parent *Handler[MetaType]
}
func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
@@ -25,6 +31,29 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
}
}
+func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent {
+ data := make(exmaps.Set[*Handler[MetaType]])
+ cont.collectHandlers(data)
+ specs := make([]*cmdschema.EventContent, 0, data.Size())
+ for handler := range data.Iter() {
+ if handler.Parameters != nil {
+ specs = append(specs, handler.Spec())
+ }
+ }
+ return specs
+}
+
+func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) {
+ cont.lock.RLock()
+ defer cont.lock.RUnlock()
+ for _, handler := range cont.commands {
+ into.Add(handler)
+ if handler.subcommandContainer != nil {
+ handler.subcommandContainer.collectHandlers(into)
+ }
+ }
+}
+
// Register registers the given command handlers.
func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) {
if cont == nil {
@@ -32,7 +61,10 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType])
}
cont.lock.Lock()
defer cont.lock.Unlock()
- for _, handler := range handlers {
+ for i, handler := range handlers {
+ if handler == nil {
+ panic(fmt.Errorf("handler #%d is nil", i+1))
+ }
cont.registerOne(handler)
}
}
@@ -45,6 +77,10 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType])
} else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists {
panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget))
}
+ if !slices.Contains(handler.parents, cont.parent) {
+ handler.parents = append(handler.parents, cont.parent)
+ handler.nestedNameCache = nil
+ }
cont.commands[handler.Name] = handler
for _, alias := range handler.Aliases {
if strings.ToLower(alias) != alias {
diff --git a/commands/event.go b/commands/event.go
index 77a3c0d2..76d6c9f0 100644
--- a/commands/event.go
+++ b/commands/event.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,7 @@ package commands
import (
"context"
+ "encoding/json"
"fmt"
"strings"
@@ -35,6 +36,8 @@ type Event[MetaType any] struct {
// RawArgs is the same as args, but without the splitting by whitespace.
RawArgs string
+ StructuredArgs json.RawMessage
+
Ctx context.Context
Log *zerolog.Logger
Proc *Processor[MetaType]
@@ -61,7 +64,7 @@ var IDHTMLParser = &format.HTMLParser{
}
// ParseEvent parses a message into a command event struct.
-func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] {
+func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" {
return nil
@@ -70,12 +73,34 @@ func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[Meta
if content.Format == event.FormatHTML {
text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx))
}
+ if content.MSC4391BotCommand != nil {
+ if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 {
+ return nil
+ }
+ wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand)
+ wrapped.RawInput = text
+ return wrapped
+ }
if len(text) == 0 {
return nil
}
return RawTextToEvent[MetaType](ctx, evt, text)
}
+func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] {
+ commandParts := strings.Split(content.Command, " ")
+ return &Event[MetaType]{
+ Event: evt,
+ // Fake a command and args to let the subcommand finder in Process work.
+ Command: commandParts[0],
+ Args: commandParts[1:],
+ Ctx: ctx,
+ Log: zerolog.Ctx(ctx),
+
+ StructuredArgs: content.Arguments,
+ }
+}
+
func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] {
parts := strings.Fields(text)
if len(parts) == 0 {
@@ -188,3 +213,25 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) {
evt.RawArgs = arg + " " + evt.RawArgs
evt.Args = append([]string{arg}, evt.Args...)
}
+
+func (evt *Event[MetaType]) ParseArgs(into any) error {
+ return json.Unmarshal(evt.StructuredArgs, into)
+}
+
+func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) {
+ err = evt.ParseArgs(&into)
+ return
+}
+
+func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) {
+ return func(evt *Event[MetaType]) {
+ parsed, err := ParseArgs[T, MetaType](evt)
+ if err != nil {
+ evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct")
+ // TODO better error, usage info? deduplicate with Process
+ evt.Reply("Failed to parse arguments: %v", err)
+ return
+ }
+ fn(evt, parsed)
+ }
+}
diff --git a/commands/handler.go b/commands/handler.go
index b01d594f..56f27f06 100644
--- a/commands/handler.go
+++ b/commands/handler.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,9 @@ package commands
import (
"strings"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/event/cmdschema"
)
type Handler[MetaType any] struct {
@@ -25,12 +28,63 @@ type Handler[MetaType any] struct {
// Event.ShiftArg will likely be useful for implementing such parameters.
PreFunc func(ce *Event[MetaType])
+ // Description is a short description of the command.
+ Description *event.ExtensibleTextContainer
+ // Parameters is a description of structured command parameters.
+ // If set, the StructuredArgs field of Event will be populated.
+ Parameters []*cmdschema.Parameter
+ TailParam string
+
+ parents []*Handler[MetaType]
+ nestedNameCache []string
subcommandContainer *CommandContainer[MetaType]
}
+func (h *Handler[MetaType]) NestedNames() []string {
+ if h.nestedNameCache != nil {
+ return h.nestedNameCache
+ }
+ nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents))
+ for _, parent := range h.parents {
+ if parent == nil {
+ nestedNames = append(nestedNames, h.Name)
+ nestedNames = append(nestedNames, h.Aliases...)
+ } else {
+ for _, parentName := range parent.NestedNames() {
+ nestedNames = append(nestedNames, parentName+" "+h.Name)
+ for _, alias := range h.Aliases {
+ nestedNames = append(nestedNames, parentName+" "+alias)
+ }
+ }
+ }
+ }
+ h.nestedNameCache = nestedNames
+ return nestedNames
+}
+
+func (h *Handler[MetaType]) Spec() *cmdschema.EventContent {
+ names := h.NestedNames()
+ return &cmdschema.EventContent{
+ Command: names[0],
+ Aliases: names[1:],
+ Parameters: h.Parameters,
+ Description: h.Description,
+ TailParam: h.TailParam,
+ }
+}
+
+func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) {
+ if h.Parameters == nil {
+ h.Parameters = other.Parameters
+ h.TailParam = other.TailParam
+ }
+ h.Func = other.Func
+}
+
func (h *Handler[MetaType]) initSubcommandContainer() {
if len(h.Subcommands) > 0 {
h.subcommandContainer = NewCommandContainer[MetaType]()
+ h.subcommandContainer.parent = h
h.subcommandContainer.Register(h.Subcommands...)
} else {
h.subcommandContainer = nil
diff --git a/commands/processor.go b/commands/processor.go
index 9341329b..80f6745d 100644
--- a/commands/processor.go
+++ b/commands/processor.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
case event.EventReaction:
parsed = proc.ParseReaction(ctx, evt)
case event.EventMessage:
- parsed = ParseEvent[MetaType](ctx, evt)
+ parsed = proc.ParseEvent(ctx, evt)
}
- if parsed == nil || !proc.PreValidator.Validate(parsed) {
+ if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) {
return
}
parsed.Proc = proc
@@ -107,6 +107,12 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
break
}
}
+ if parsed.StructuredArgs != nil && len(parsed.Args) > 0 {
+ // TODO allow unknown command handlers to be called?
+ // The client sent MSC4391 data, but the target command wasn't found
+ log.Debug().Msg("Didn't find handler for MSC4391 command")
+ return
+ }
logWith := log.With().
Str("command", parsed.Command).
@@ -116,11 +122,31 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
}
if proc.LogArgs {
logWith = logWith.Strs("args", parsed.Args)
+ if parsed.StructuredArgs != nil {
+ logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs)
+ }
}
log = logWith.Logger()
parsed.Ctx = log.WithContext(ctx)
parsed.Log = &log
+ if handler.Parameters != nil && parsed.StructuredArgs == nil {
+ // The handler wants structured parameters, but the client didn't send MSC4391 data
+ var err error
+ parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs)
+ if err != nil {
+ log.Debug().Err(err).Msg("Failed to parse structured arguments")
+ // TODO better error, usage info? deduplicate with WithParsedArgs
+ parsed.Reply("Failed to parse arguments: %v", err)
+ return
+ }
+ if proc.LogArgs {
+ log.UpdateContext(func(c zerolog.Context) zerolog.Context {
+ return c.RawJSON("structured_args", parsed.StructuredArgs)
+ })
+ }
+ }
+
log.Debug().Msg("Processing command")
handler.Func(parsed)
}
diff --git a/commands/reactions.go b/commands/reactions.go
index 0df372e5..0d316219 100644
--- a/commands/reactions.go
+++ b/commands/reactions.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2025 Tulir Asokan
+// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,7 @@ package commands
import (
"context"
+ "encoding/json"
"strings"
"github.com/rs/zerolog"
@@ -19,6 +20,11 @@ import (
const ReactionCommandsKey = "fi.mau.reaction_commands"
const ReactionMultiUseKey = "fi.mau.reaction_multi_use"
+type ReactionCommandData struct {
+ Command string `json:"command"`
+ Args any `json:"args,omitempty"`
+}
+
func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.ReactionEventContent)
if !ok {
@@ -67,21 +73,33 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E
Msg("Reaction command not found in target event")
return nil
}
- cmdString, ok := rawCmd.(string)
- if !ok {
+ var wrappedEvt *Event[MetaType]
+ switch typedCmd := rawCmd.(type) {
+ case string:
+ wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd)
+ case map[string]any:
+ var input event.MSC4391BotCommandInput
+ if marshaled, err := json.Marshal(typedCmd); err != nil {
+
+ } else if err = json.Unmarshal(marshaled, &input); err != nil {
+
+ } else {
+ wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input)
+ }
+ }
+ if wrappedEvt == nil {
zerolog.Ctx(ctx).Debug().
Stringer("target_event_id", evtID).
Str("reaction_key", content.RelatesTo.Key).
Msg("Reaction command data is invalid")
return nil
}
- wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString)
wrappedEvt.Proc = proc
wrappedEvt.Redact()
if !isMultiUse {
DeleteAllReactions(ctx, proc.Client, evt)
}
- if cmdString == "" {
+ if wrappedEvt.Command == "" {
return nil
}
return wrappedEvt
diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go
index 155cca5c..727aacbf 100644
--- a/crypto/attachment/attachments.go
+++ b/crypto/attachment/attachments.go
@@ -21,13 +21,24 @@ import (
)
var (
- HashMismatch = errors.New("mismatching SHA-256 digest")
- UnsupportedVersion = errors.New("unsupported Matrix file encryption version")
- UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
- InvalidKey = errors.New("failed to decode key")
- InvalidInitVector = errors.New("failed to decode initialization vector")
- InvalidHash = errors.New("failed to decode SHA-256 hash")
- ReaderClosed = errors.New("encrypting reader was already closed")
+ ErrHashMismatch = errors.New("mismatching SHA-256 digest")
+ ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version")
+ ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
+ ErrInvalidKey = errors.New("failed to decode key")
+ ErrInvalidInitVector = errors.New("failed to decode initialization vector")
+ ErrInvalidHash = errors.New("failed to decode SHA-256 hash")
+ ErrReaderClosed = errors.New("encrypting reader was already closed")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ HashMismatch = ErrHashMismatch
+ UnsupportedVersion = ErrUnsupportedVersion
+ UnsupportedAlgorithm = ErrUnsupportedAlgorithm
+ InvalidKey = ErrInvalidKey
+ InvalidInitVector = ErrInvalidInitVector
+ InvalidHash = ErrInvalidHash
+ ReaderClosed = ErrReaderClosed
)
var (
@@ -85,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error {
if ef.decoded != nil {
return nil
} else if len(ef.Key.Key) != keyBase64Length {
- return InvalidKey
+ return ErrInvalidKey
} else if len(ef.InitVector) != ivBase64Length {
- return InvalidInitVector
+ return ErrInvalidInitVector
} else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length {
- return InvalidHash
+ return ErrInvalidHash
}
ef.decoded = &decodedKeys{}
_, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key))
if err != nil {
- return InvalidKey
+ return ErrInvalidKey
}
_, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector))
if err != nil {
- return InvalidInitVector
+ return ErrInvalidInitVector
}
if includeHash {
_, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256))
if err != nil {
- return InvalidHash
+ return ErrInvalidHash
}
}
return nil
@@ -179,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil)
func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
if r.closed {
- return 0, ReaderClosed
+ return 0, ErrReaderClosed
}
if offset != 0 || whence != io.SeekStart {
return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported")
@@ -200,7 +211,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
func (r *encryptingReader) Read(dst []byte) (n int, err error) {
if r.closed {
- return 0, ReaderClosed
+ return 0, ErrReaderClosed
} else if r.isDecrypting && r.file.decoded == nil {
if err = r.file.PrepareForDecryption(); err != nil {
return
@@ -224,7 +235,7 @@ func (r *encryptingReader) Close() (err error) {
}
if r.isDecrypting {
if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) {
- return HashMismatch
+ return ErrHashMismatch
}
} else {
r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil))
@@ -265,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) {
// DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function.
func (ef *EncryptedFile) PrepareForDecryption() error {
if ef.Version != "v2" {
- return UnsupportedVersion
+ return ErrUnsupportedVersion
} else if ef.Key.Algorithm != "A256CTR" {
- return UnsupportedAlgorithm
+ return ErrUnsupportedAlgorithm
} else if err := ef.decodeKeys(true); err != nil {
return err
}
@@ -281,7 +292,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
}
dataHash := sha256.Sum256(data)
if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) {
- return HashMismatch
+ return ErrHashMismatch
}
utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
return nil
diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go
index d7f1394a..9fe929ab 100644
--- a/crypto/attachment/attachments_test.go
+++ b/crypto/attachment/attachments_test.go
@@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) {
file := parseHelloWorld()
file.Version = "foo"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, UnsupportedVersion)
+ assert.ErrorIs(t, err, ErrUnsupportedVersion)
}
func TestUnsupportedAlgorithm(t *testing.T) {
file := parseHelloWorld()
file.Key.Algorithm = "bar"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, UnsupportedAlgorithm)
+ assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
}
func TestHashMismatch(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes))
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, HashMismatch)
+ assert.ErrorIs(t, err, ErrHashMismatch)
}
func TestTooLongHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, InvalidHash)
+ assert.ErrorIs(t, err, ErrInvalidHash)
}
func TestTooShortHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "5/Gy1JftyyQ"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
- assert.ErrorIs(t, err, InvalidHash)
+ assert.ErrorIs(t, err, ErrInvalidHash)
}
diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go
index 4094f695..5d9bf5b3 100644
--- a/crypto/cross_sign_key.go
+++ b/crypto/cross_sign_key.go
@@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross
}
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
- err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
+ err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{
Master: masterKey,
SelfSigning: selfKey,
UserSigning: userKey,
diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go
index f85d1ea3..223fc7b5 100644
--- a/crypto/cross_sign_pubkey.go
+++ b/crypto/cross_sign_pubkey.go
@@ -63,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id
if len(dbKeys) > 0 {
masterKey, ok := dbKeys[id.XSUsageMaster]
if ok {
- selfSigning, _ := dbKeys[id.XSUsageSelfSigning]
- userSigning, _ := dbKeys[id.XSUsageUserSigning]
+ selfSigning := dbKeys[id.XSUsageSelfSigning]
+ userSigning := dbKeys[id.XSUsageUserSigning]
return &CrossSigningPublicKeysCache{
MasterKey: masterKey.Key,
SelfSigningKey: selfSigning.Key,
diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go
index 50b58ea0..fd42880d 100644
--- a/crypto/cross_sign_ssss.go
+++ b/crypto/cross_sign_ssss.go
@@ -8,6 +8,7 @@ package crypto
import (
"context"
+ "errors"
"fmt"
"maunium.net/go/mautrix"
@@ -77,7 +78,11 @@ func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey s
return fmt.Errorf("failed to get default SSSS key data: %w", err)
}
key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
- if err != nil {
+ if errors.Is(err, ssss.ErrUnverifiableKey) {
+ mach.machOrContextLog(ctx).Warn().
+ Str("key_id", keyID).
+ Msg("SSSS key is unverifiable, trying to use without verifying")
+ } else if err != nil {
return err
}
err = mach.FetchCrossSigningKeysFromSSSS(ctx, key)
diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go
index d30b7e32..57406b11 100644
--- a/crypto/cross_sign_store.go
+++ b/crypto/cross_sign_store.go
@@ -26,24 +26,22 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
}
- if currentKeys != nil {
- for curKeyUsage, curKey := range currentKeys {
- log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
- // got a new key with the same usage as an existing key
- for _, newKeyUsage := range userKeys.Usage {
- if newKeyUsage == curKeyUsage {
- if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
- // old key is not in the new key map, so we drop signatures made by it
- if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
- log.Error().Err(err).Msg("Error deleting old signatures made by user")
- } else {
- log.Debug().
- Int64("signature_count", count).
- Msg("Dropped signatures made by old key as it has been replaced")
- }
+ for curKeyUsage, curKey := range currentKeys {
+ log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
+ // got a new key with the same usage as an existing key
+ for _, newKeyUsage := range userKeys.Usage {
+ if newKeyUsage == curKeyUsage {
+ if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
+ // old key is not in the new key map, so we drop signatures made by it
+ if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
+ log.Error().Err(err).Msg("Error deleting old signatures made by user")
+ } else {
+ log.Debug().
+ Int64("signature_count", count).
+ Msg("Dropped signatures made by old key as it has been replaced")
}
- break
}
+ break
}
}
}
diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go
index 1939ea79..b62dc128 100644
--- a/crypto/cryptohelper/cryptohelper.go
+++ b/crypto/cryptohelper/cryptohelper.go
@@ -278,7 +278,7 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error
}
}
-var NoSessionFound = crypto.NoSessionFound
+var NoSessionFound = crypto.ErrNoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
@@ -371,6 +371,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
+ //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@@ -418,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R
defer helper.lock.RUnlock()
encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content)
if err != nil {
- if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
+ if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) {
return
}
helper.log.Debug().
diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go
index 47279474..457d5a0c 100644
--- a/crypto/decryptmegolm.go
+++ b/crypto/decryptmegolm.go
@@ -24,13 +24,23 @@ import (
)
var (
- IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
- NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
- DuplicateMessageIndex = errors.New("duplicate megolm message index")
- WrongRoom = errors.New("encrypted megolm event is not intended for this room")
- DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
- SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
- RatchetError = errors.New("failed to ratchet session after use")
+ ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
+ ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
+ ErrDuplicateMessageIndex = errors.New("duplicate megolm message index")
+ ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room")
+ ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
+ ErrRatchetError = errors.New("failed to ratchet session after use")
+ ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType
+ NoSessionFound = ErrNoSessionFound
+ DuplicateMessageIndex = ErrDuplicateMessageIndex
+ WrongRoom = ErrWrongRoom
+ DeviceKeyMismatch = ErrDeviceKeyMismatch
+ RatchetError = ErrRatchetError
)
type megolmEvent struct {
@@ -45,13 +55,30 @@ var (
relatesToTopLevelPath = exgjson.Path("content", "m.relates_to")
)
+const sessionIDLength = 43
+
+func validateCiphertextCharacters(ciphertext []byte) bool {
+ for _, b := range ciphertext {
+ if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' {
+ return false
+ }
+ }
+ return true
+}
+
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
- return nil, IncorrectEncryptedContentType
+ return nil, ErrIncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmMegolmV1 {
- return nil, UnsupportedAlgorithm
+ return nil, ErrUnsupportedAlgorithm
+ } else if len(content.MegolmCiphertext) < 74 {
+ return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext))
+ } else if len(content.SessionID) != sessionIDLength {
+ return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID))
+ } else if !validateCiphertextCharacters(content.MegolmCiphertext) {
+ return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload)
}
log := mach.machOrContextLog(ctx).With().
Str("action", "decrypt megolm event").
@@ -97,7 +124,13 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
Msg("Couldn't resolve trust level of session: sent by unknown device")
trustLevel = id.TrustStateUnknownDevice
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
- return nil, DeviceKeyMismatch
+ log.Debug().
+ Stringer("session_sender_key", sess.SenderKey).
+ Stringer("device_sender_key", device.IdentityKey).
+ Stringer("session_signing_key", sess.SigningKey).
+ Stringer("device_signing_key", device.SigningKey).
+ Msg("Device keys don't match keys in session, marking as untrusted")
+ trustLevel = id.TrustStateDeviceKeyMismatch
} else {
trustLevel, err = mach.ResolveTrustContext(ctx, device)
if err != nil {
@@ -147,7 +180,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
if err != nil {
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
} else if megolmEvt.RoomID != encryptionRoomID {
- return nil, WrongRoom
+ return nil, ErrWrongRoom
}
if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState {
megolmEvt.Type.Class = event.StateEventType
@@ -180,6 +213,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
TrustSource: device,
ForwardedKeys: forwardedKeys,
WasEncrypted: true,
+ EventSource: evt.Mautrix.EventSource | event.SourceDecrypted,
ReceivedAt: evt.Mautrix.ReceivedAt,
},
}, nil
@@ -201,19 +235,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
if decodeErr != nil {
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
- return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
+ return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex)
}
firstKnown := sess.Internal.FirstKnownIndex()
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
- return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
} else if !ok {
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
- return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown)
}
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
- return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
+ return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
}
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
@@ -224,13 +258,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
- return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
- } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
- return sess, nil, 0, SenderKeyMismatch
+ return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID)
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
- if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
+ if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
}
@@ -238,7 +270,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
- return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
+ return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex)
}
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
@@ -290,24 +322,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Info().Msg("Deleted fully used session")
}
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
- return sess, plaintext, messageIndex, RatchetError
+ return sess, plaintext, messageIndex, ErrRatchetError
} else {
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
}
diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go
index 30cc4cfe..aea5e6dc 100644
--- a/crypto/decryptolm.go
+++ b/crypto/decryptolm.go
@@ -26,15 +26,27 @@ import (
)
var (
- UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
- NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
- UnsupportedOlmMessageType = errors.New("unsupported olm message type")
- DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
- DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
- SenderMismatch = errors.New("mismatched sender in olm payload")
- RecipientMismatch = errors.New("mismatched recipient in olm payload")
- RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
- ErrDuplicateMessage = errors.New("duplicate olm message")
+ ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
+ ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
+ ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type")
+ ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
+ ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
+ ErrSenderMismatch = errors.New("mismatched sender in olm payload")
+ ErrRecipientMismatch = errors.New("mismatched recipient in olm payload")
+ ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
+ ErrDuplicateMessage = errors.New("duplicate olm message")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ UnsupportedAlgorithm = ErrUnsupportedAlgorithm
+ NotEncryptedForMe = ErrNotEncryptedForMe
+ UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType
+ DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession
+ DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage
+ SenderMismatch = ErrSenderMismatch
+ RecipientMismatch = ErrRecipientMismatch
+ RecipientKeyMismatch = ErrRecipientKeyMismatch
)
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
@@ -56,13 +68,13 @@ type DecryptedOlmEvent struct {
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
- return nil, IncorrectEncryptedContentType
+ return nil, ErrIncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmOlmV1 {
- return nil, UnsupportedAlgorithm
+ return nil, ErrUnsupportedAlgorithm
}
ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()]
if !ok {
- return nil, NotEncryptedForMe
+ return nil, ErrNotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
@@ -78,7 +90,7 @@ type OlmEventKeys struct {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
- return nil, UnsupportedOlmMessageType
+ return nil, ErrUnsupportedOlmMessageType
}
log := mach.machOrContextLog(ctx).With().
@@ -102,11 +114,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
olmEvt.Type.Class = evt.Type.Class
if evt.Sender != olmEvt.Sender {
- return nil, SenderMismatch
+ return nil, ErrSenderMismatch
} else if mach.Client.UserID != olmEvt.Recipient {
- return nil, RecipientMismatch
+ return nil, ErrRecipientMismatch
} else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 {
- return nil, RecipientKeyMismatch
+ return nil, ErrRecipientKeyMismatch
}
if len(olmEvt.Content.VeryRaw) > 0 {
@@ -122,6 +134,9 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
func olmMessageHash(ciphertext string) ([32]byte, error) {
+ if ciphertext == "" {
+ return [32]byte{}, fmt.Errorf("empty ciphertext")
+ }
ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext)
return sha256.Sum256(ciphertextBytes), err
}
@@ -151,7 +166,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
if err != nil {
- if err == DecryptionFailedWithMatchingSession {
+ if err == ErrDecryptionFailedWithMatchingSession {
log.Warn().Msg("Found matching session, but decryption failed")
go mach.unwedgeDevice(log, sender, senderKey)
}
@@ -169,10 +184,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(log, sender, senderKey)
- return nil, DecryptionFailedForNormalMessage
+ return nil, ErrDecryptionFailedForNormalMessage
}
- accountBackup, err := mach.account.Internal.Pickle([]byte("tmp"))
+ accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp"))
log.Trace().Msg("Trying to create inbound session")
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
@@ -302,7 +317,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
Str("session_description", session.Describe()).
Msg("Failed to decrypt olm message")
if olmType == id.OlmMsgTypePreKey {
- return nil, DecryptionFailedWithMatchingSession
+ return nil, ErrDecryptionFailedWithMatchingSession
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
@@ -345,7 +360,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send
ctx := log.WithContext(mach.backgroundCtx)
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
- delta := time.Now().Sub(prevUnwedge)
+ delta := time.Since(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
log.Debug().
Str("previous_recreation", delta.String()).
diff --git a/crypto/devicelist.go b/crypto/devicelist.go
index 61a22522..f0d2b129 100644
--- a/crypto/devicelist.go
+++ b/crypto/devicelist.go
@@ -22,14 +22,23 @@ import (
)
var (
- MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
- MismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
- MismatchingSigningKey = errors.New("received update for device with different signing key")
- NoSigningKeyFound = errors.New("didn't find ed25519 signing key")
- NoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
- InvalidKeySignature = errors.New("invalid signature on device keys")
+ ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
+ ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
+ ErrMismatchingSigningKey = errors.New("received update for device with different signing key")
+ ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key")
+ ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
+ ErrInvalidKeySignature = errors.New("invalid signature on device keys")
+ ErrUserNotTracked = errors.New("user is not tracked")
+)
- ErrUserNotTracked = errors.New("user is not tracked")
+// Deprecated: use variables prefixed with Err
+var (
+ MismatchingDeviceID = ErrMismatchingDeviceID
+ MismatchingUserID = ErrMismatchingUserID
+ MismatchingSigningKey = ErrMismatchingSigningKey
+ NoSigningKeyFound = ErrNoSigningKeyFound
+ NoIdentityKeyFound = ErrNoIdentityKeyFound
+ InvalidKeySignature = ErrInvalidKeySignature
)
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
@@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID)
func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) {
if deviceID != deviceKeys.DeviceID {
- return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID)
+ return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID)
} else if userID != deviceKeys.UserID {
- return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID)
+ return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID)
}
signingKey := deviceKeys.Keys.GetEd25519(deviceID)
identityKey := deviceKeys.Keys.GetCurve25519(deviceID)
if signingKey == "" {
- return nil, NoSigningKeyFound
+ return nil, ErrNoSigningKeyFound
} else if identityKey == "" {
- return nil, NoIdentityKeyFound
+ return nil, ErrNoIdentityKeyFound
}
if existing != nil && existing.SigningKey != signingKey {
- return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
+ return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey)
}
ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
if err != nil {
return existing, fmt.Errorf("failed to verify signature: %w", err)
} else if !ok {
- return existing, InvalidKeySignature
+ return existing, ErrInvalidKeySignature
}
name, ok := deviceKeys.Unsigned["device_display_name"].(string)
diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go
index b3d19618..88f9c8d4 100644
--- a/crypto/encryptmegolm.go
+++ b/crypto/encryptmegolm.go
@@ -25,7 +25,12 @@ import (
)
var (
- NoGroupSession = errors.New("no group session created")
+ ErrNoGroupSession = errors.New("no group session created")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ NoGroupSession = ErrNoGroupSession
)
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
@@ -82,15 +87,20 @@ type rawMegolmEvent struct {
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
func IsShareError(err error) bool {
- return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
+ return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession
}
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
+ if len(ciphertext) == 0 {
+ return 0, fmt.Errorf("empty ciphertext")
+ }
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
var err error
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
if err != nil {
return 0, err
+ } else if len(decoded) < 2+binary.MaxVarintLen64 {
+ return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded))
} else if decoded[0] != 3 || decoded[1] != 8 {
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
}
@@ -120,7 +130,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
} else if session == nil {
- return nil, NoGroupSession
+ return nil, ErrNoGroupSession
}
plaintext, err := json.Marshal(&rawMegolmEvent{
RoomID: roomID,
@@ -164,6 +174,15 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
SenderKey: mach.account.IdentityKey(),
DeviceID: mach.Client.DeviceID,
}
+ if mach.MSC4392Relations && encrypted.RelatesTo != nil {
+ // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content.
+ // Other relations like threads are still left unencrypted.
+ encrypted.RelatesTo.InReplyTo = nil
+ encrypted.RelatesTo.IsFallingBack = false
+ if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" {
+ encrypted.RelatesTo = nil
+ }
+ }
if mach.PlaintextMentions {
encrypted.Mentions = getMentions(content)
}
@@ -351,26 +370,19 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
log.Trace().Msg("Encrypting group session for all found devices")
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
+ logUsers := zerolog.Dict()
for userID, sessions := range olmSessions {
if len(sessions) == 0 {
continue
}
+ logDevices := zerolog.Dict()
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
- log.Trace().
- Stringer("target_user_id", userID).
- Stringer("target_device_id", deviceID).
- Stringer("target_identity_key", device.identity.IdentityKey).
- Msg("Encrypting group session for device")
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
+ logDevices.Str(string(deviceID), string(device.identity.IdentityKey))
deviceCount++
- log.Debug().
- Stringer("target_user_id", userID).
- Stringer("target_device_id", deviceID).
- Stringer("target_identity_key", device.identity.IdentityKey).
- Msg("Encrypted group session for device")
if !mach.DisableSharedGroupSessionTracking {
err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id)
if err != nil {
@@ -384,11 +396,13 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
}
}
}
+ logUsers.Dict(string(userID), logDevices)
}
log.Debug().
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
+ Dict("destination_map", logUsers).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
return err
diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go
index 80b76dc5..765307af 100644
--- a/crypto/encryptolm.go
+++ b/crypto/encryptolm.go
@@ -96,15 +96,19 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
panic(err)
}
log := mach.machOrContextLog(ctx)
- log.Debug().
- Str("recipient_identity_key", recipient.IdentityKey.String()).
- Str("olm_session_id", session.ID().String()).
- Str("olm_session_description", session.Describe()).
- Msg("Encrypting olm message")
msgType, ciphertext, err := session.Encrypt(plaintext)
if err != nil {
panic(err)
}
+ ciphertextStr := string(ciphertext)
+ ciphertextHash, _ := olmMessageHash(ciphertextStr)
+ log.Debug().
+ Stringer("event_type", evtType).
+ Str("recipient_identity_key", recipient.IdentityKey.String()).
+ Str("olm_session_id", session.ID().String()).
+ Str("olm_session_description", session.Describe()).
+ Hex("ciphertext_hash", ciphertextHash[:]).
+ Msg("Encrypted olm message")
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
if err != nil {
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
@@ -115,7 +119,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
OlmCiphertext: event.OlmCiphertexts{
recipient.IdentityKey: {
Type: msgType,
- Body: string(ciphertext),
+ Body: ciphertextStr,
},
},
}
diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go
index 4da08a73..b48843a4 100644
--- a/crypto/goolm/account/account.go
+++ b/crypto/goolm/account/account.go
@@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
if err != nil {
return err
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
- return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
return err
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair
diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go
index e1c9b452..d0dec5f0 100644
--- a/crypto/goolm/account/account_test.go
+++ b/crypto/goolm/account/account_test.go
@@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) {
account, err := account.NewAccount()
assert.NoError(t, err)
err = account.Unpickle(pickled, pickleKey)
- assert.ErrorIs(t, err, olm.ErrBadVersion)
+ assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion)
}
func TestLoopback(t *testing.T) {
diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go
index e9759501..6e42d886 100644
--- a/crypto/goolm/crypto/curve25519.go
+++ b/crypto/goolm/crypto/curve25519.go
@@ -53,6 +53,7 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
// SharedSecret returns the shared secret between the key pair and the given public key.
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
+ // Note: the standard library checks that the output is non-zero
return c.PrivateKey.SharedSecret(pubKey)
}
diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go
index 9039c126..2550f15e 100644
--- a/crypto/goolm/crypto/curve25519_test.go
+++ b/crypto/goolm/crypto/curve25519_test.go
@@ -25,6 +25,8 @@ func TestCurve25519(t *testing.T) {
fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey)
assert.NoError(t, err)
assert.Equal(t, fromPrivate, firstKeypair)
+ _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength))
+ assert.Error(t, err)
}
func TestCurve25519Case1(t *testing.T) {
diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go
index 061a052a..58ee26f7 100644
--- a/crypto/goolm/goolmbase64/base64.go
+++ b/crypto/goolm/goolmbase64/base64.go
@@ -4,7 +4,8 @@ import (
"encoding/base64"
)
-// Deprecated: base64.RawStdEncoding should be used directly
+// These methods should only be used for raw byte operations, never with string conversion
+
func Decode(input []byte) ([]byte, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
@@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) {
return decoded[:writtenBytes], nil
}
-// Deprecated: base64.RawStdEncoding should be used directly
func Encode(input []byte) []byte {
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
base64.RawStdEncoding.Encode(encoded, input)
diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/libolmpickle/picklejson.go
index 308e472c..f765391f 100644
--- a/crypto/goolm/libolmpickle/picklejson.go
+++ b/crypto/goolm/libolmpickle/picklejson.go
@@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
}
}
if decrypted[0] != pickleVersion {
- return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion)
+ return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion)
}
err = json.Unmarshal(decrypted[1:], object)
if err != nil {
diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go
index f3d22500..c83540c1 100644
--- a/crypto/goolm/message/group_message.go
+++ b/crypto/goolm/message/group_message.go
@@ -39,7 +39,7 @@ func (r *GroupMessage) Decode(input []byte) (err error) {
return
}
if r.Version != protocolVersion {
- return fmt.Errorf("GroupMessage.Decode: %w", olm.ErrWrongProtocolVersion)
+ return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
}
for {
diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go
index 9ef93630..b161a2d1 100644
--- a/crypto/goolm/message/message.go
+++ b/crypto/goolm/message/message.go
@@ -43,7 +43,7 @@ func (r *Message) Decode(input []byte) (err error) {
return
}
if r.Version != protocolVersion {
- return fmt.Errorf("Message.Decode: %w", olm.ErrWrongProtocolVersion)
+ return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
}
for {
diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go
index 760be4c9..4e3d495d 100644
--- a/crypto/goolm/message/prekey_message.go
+++ b/crypto/goolm/message/prekey_message.go
@@ -48,7 +48,7 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) {
return
}
if r.Version != protocolVersion {
- return fmt.Errorf("PreKeyMessage.Decode: %w", olm.ErrWrongProtocolVersion)
+ return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
}
for {
diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go
index 956868b2..d58dbb21 100644
--- a/crypto/goolm/message/session_export.go
+++ b/crypto/goolm/message/session_export.go
@@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error {
return fmt.Errorf("decrypt: %w", olm.ErrBadInput)
}
if input[0] != sessionExportVersion {
- return fmt.Errorf("decrypt: %w", olm.ErrBadVersion)
+ return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go
index 16240945..d04ef15a 100644
--- a/crypto/goolm/message/session_sharing.go
+++ b/crypto/goolm/message/session_sharing.go
@@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
}
s.PublicKey = publicKey
if input[0] != sessionSharingVersion {
- return fmt.Errorf("verify: %w", olm.ErrBadVersion)
+ return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go
index afb01f74..cdb20eb1 100644
--- a/crypto/goolm/pk/decryption.go
+++ b/crypto/goolm/pk/decryption.go
@@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error {
if pickledVersion == decryptionPickleVersionLibOlm {
return a.KeyPair.UnpickleLibOlm(decoder)
} else {
- return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm)
+ return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm)
}
}
diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go
index 23f67ddf..2897d9b0 100644
--- a/crypto/goolm/pk/encryption.go
+++ b/crypto/goolm/pk/encryption.go
@@ -37,6 +37,9 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat
return nil, nil, err
}
cipher, err := aessha2.NewAESSHA2(sharedSecret, nil)
+ if err != nil {
+ return nil, nil, err
+ }
ciphertext, err = cipher.Encrypt(plaintext)
if err != nil {
return nil, nil, err
diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go
index 229c9bd2..9901ada8 100644
--- a/crypto/goolm/ratchet/olm.go
+++ b/crypto/goolm/ratchet/olm.go
@@ -142,7 +142,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) {
return nil, err
}
if message.Version != protocolVersion {
- return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion)
+ return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion)
}
if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 {
return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat)
diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go
index 80dd71cc..7ccbd26d 100644
--- a/crypto/goolm/session/megolm_inbound_session.go
+++ b/crypto/goolm/session/megolm_inbound_session.go
@@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet,
}
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
// the counter is before our initial ratchet - we can't decode this
- return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable)
+ return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex)
}
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
copiedRatchet := o.InitialRatchet
@@ -126,7 +126,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error)
return nil, 0, err
}
if msg.Version != protocolVersion {
- return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion)
+ return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion)
}
if msg.Ciphertext == nil || !msg.HasMessageIndex {
return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat)
@@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error {
return err
}
if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 {
- return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil {
diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go
index 2b8e1c84..7f923534 100644
--- a/crypto/goolm/session/megolm_outbound_session.go
+++ b/crypto/goolm/session/megolm_outbound_session.go
@@ -101,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
- if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
- return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ if err != nil {
+ return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err)
+ } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm {
+ return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil {
return err
diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go
index b99ab630..a1cb8d66 100644
--- a/crypto/goolm/session/olm_session.go
+++ b/crypto/goolm/session/olm_session.go
@@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received
msg := message.Message{}
err = msg.Decode(oneTimeMsg.Message)
if err != nil {
- return nil, fmt.Errorf("Message decode: %w", err)
+ return nil, fmt.Errorf("message decode: %w", err)
}
if len(msg.RatchetKey) == 0 {
- return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat)
+ return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat)
}
//Init Ratchet
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
@@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID {
copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey)
copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey)
hash := sha256.Sum256(message)
- res := id.SessionID(goolmbase64.Encode(hash[:]))
+ res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:]))
return res
}
@@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e
if len(crypttext) == 0 {
return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput)
}
- decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext))
+ decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext)
if err != nil {
return nil, err
}
@@ -365,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error {
func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
decoder := libolmpickle.NewDecoder(buf)
pickledVersion, err := decoder.ReadUInt32()
+ if err != nil {
+ return fmt.Errorf("unpickle olmSession: failed to read version: %w", err)
+ }
var includesChainIndex bool
switch pickledVersion {
@@ -373,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error {
case uint32(0x80000001):
includesChainIndex = true
default:
- return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
+ return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
}
if o.ReceivedMessage, err = decoder.ReadBool(); err != nil {
diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go
index a88d12f6..b95a44ac 100644
--- a/crypto/goolm/session/register.go
+++ b/crypto/goolm/session/register.go
@@ -14,7 +14,7 @@ func Register() {
// Inbound Session
olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
if len(key) == 0 {
key = []byte(" ")
@@ -23,13 +23,13 @@ func Register() {
}
olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
return NewMegolmInboundSession(sessionKey)
}
olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
return NewMegolmInboundSessionFromExport(sessionKey)
}
@@ -40,7 +40,7 @@ func Register() {
// Outbound Session
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
lenKey := len(key)
if lenKey == 0 {
diff --git a/crypto/keybackup.go b/crypto/keybackup.go
index d8b3d715..7b3c30db 100644
--- a/crypto/keybackup.go
+++ b/crypto/keybackup.go
@@ -56,11 +56,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context,
// ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private
// key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing
// from a verified device belonging to the same user."
- megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
- if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
- log.Debug().Msg("key backup is trusted based on derived public key")
- return versionInfo, nil
- } else {
+ if megolmBackupKey != nil {
+ megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes()))
+ if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey {
+ log.Debug().Msg("Key backup is trusted based on derived public key")
+ return versionInfo, nil
+ }
log.Debug().
Stringer("expected_key", megolmBackupDerivedPublicKey).
Stringer("actual_key", versionInfo.AuthData.PublicKey).
@@ -199,13 +200,14 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving(
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
RoomID: roomID,
- ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()),
+ ForwardingChains: keyBackupData.ForwardingKeyChain,
id: sessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
KeyBackupVersion: version,
+ KeySource: id.KeySourceBackup,
}, nil
}
diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go
index 47616a20..fd6f105d 100644
--- a/crypto/keyexport_test.go
+++ b/crypto/keyexport_test.go
@@ -31,5 +31,5 @@ func TestExportKeys(t *testing.T) {
))
data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess})
assert.NoError(t, err)
- assert.Len(t, data, 836)
+ assert.Len(t, data, 893)
}
diff --git a/crypto/keyimport.go b/crypto/keyimport.go
index 36ad6b9c..3ffc74a5 100644
--- a/crypto/keyimport.go
+++ b/crypto/keyimport.go
@@ -108,19 +108,20 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
return false, ErrMismatchingExportedSessionID
}
igs := &InboundGroupSession{
- Internal: igsInternal,
- SigningKey: session.SenderClaimedKeys.Ed25519,
- SenderKey: session.SenderKey,
- RoomID: session.RoomID,
- // TODO should we add something here to mark the signing key as unverified like key requests do?
+ Internal: igsInternal,
+ SigningKey: session.SenderClaimedKeys.Ed25519,
+ SenderKey: session.SenderKey,
+ RoomID: session.RoomID,
ForwardingChains: session.ForwardingChains,
-
- ReceivedAt: time.Now().UTC(),
+ KeySource: id.KeySourceImport,
+ ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
firstKnownIndex := igs.Internal.FirstKnownIndex()
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex {
- // We already have an equivalent or better session in the store, so don't override it.
+ // We already have an equivalent or better session in the store, so don't override it,
+ // but do notify the session received callback just in case.
+ mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex())
return false, nil
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
diff --git a/crypto/keysharing.go b/crypto/keysharing.go
index f1d427af..19a68c87 100644
--- a/crypto/keysharing.go
+++ b/crypto/keysharing.go
@@ -189,6 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
+ KeySource: id.KeySourceForward,
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
@@ -214,6 +215,7 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare
RoomID: request.RoomID,
Algorithm: request.Algorithm,
SessionID: request.SessionID,
+ //lint:ignore SA1019 This is just echoing back the deprecated field
SenderKey: request.SenderKey,
Code: rejection.Code,
Reason: rejection.Reason,
@@ -263,9 +265,14 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev
log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing")
return &KeyShareRejectNoResponse
} else if !isShared {
- // TODO differentiate session not shared with requester vs session not created by this device?
- log.Debug().Msg("Rejecting key request for unshared session")
- return &KeyShareRejectNotRecipient
+ igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID)
+ if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey {
+ log.Debug().Msg("Rejecting key request for unshared session")
+ return &KeyShareRejectNotRecipient
+ }
+ // Note: this case will also happen for redacted sessions and database errors
+ log.Debug().Msg("Rejecting key request for session created by another device")
+ return &KeyShareRejectNoResponse
}
log.Debug().Msg("Accepting key request for shared session")
return nil
@@ -323,7 +330,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
- mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ if sender != mach.Client.UserID {
+ mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ }
} else {
log.Error().Err(err).Msg("Failed to get group session to forward")
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
@@ -331,7 +340,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
return
} else if igs == nil {
log.Error().Msg("Didn't find group session to forward")
- mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ if sender != mach.Client.UserID {
+ mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
+ }
return
}
if internalID := igs.ID(); internalID != content.Body.SessionID {
@@ -356,7 +367,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
SessionID: igs.ID(),
SessionKey: string(exportedKey),
},
- SenderKey: content.Body.SenderKey,
+ SenderKey: igs.SenderKey,
ForwardingKeyChain: igs.ForwardingChains,
SenderClaimedKey: igs.SigningKey,
},
diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go
index f6f916e7..0350f083 100644
--- a/crypto/libolm/account.go
+++ b/crypto/libolm/account.go
@@ -33,7 +33,7 @@ var _ olm.Account = (*Account)(nil)
// "INVALID_BASE64".
func AccountFromPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
a := NewBlankAccount()
return a, a.Unpickle(pickled, key)
@@ -53,7 +53,7 @@ func NewAccount() (*Account, error) {
random := make([]byte, a.createRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
- panic(olm.NotEnoughGoRandom)
+ panic(olm.ErrNotEnoughGoRandom)
}
ret := C.olm_create_account(
(*C.OlmAccount)(a.int),
@@ -128,7 +128,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
// supplied key.
func (a *Account) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, a.pickleLen())
r := C.olm_pickle_account(
@@ -145,7 +145,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) {
func (a *Account) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_account(
(*C.OlmAccount)(a.int),
@@ -198,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) {
// Deprecated
func (a *Account) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if a.int == nil {
*a = *NewBlankAccount()
@@ -235,7 +235,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
// Account.
func (a *Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
- panic(olm.EmptyInput)
+ panic(olm.ErrEmptyInput)
}
signature := make([]byte, a.signatureLen())
r := C.olm_account_sign(
@@ -299,7 +299,7 @@ func (a *Account) GenOneTimeKeys(num uint) error {
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
_, err := rand.Read(random)
if err != nil {
- return olm.NotEnoughGoRandom
+ return olm.ErrNotEnoughGoRandom
}
r := C.olm_account_generate_one_time_keys(
(*C.OlmAccount)(a.int),
@@ -319,13 +319,13 @@ func (a *Account) GenOneTimeKeys(num uint) error {
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
random := make([]byte, s.createOutboundRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
- panic(olm.NotEnoughGoRandom)
+ panic(olm.ErrNotEnoughGoRandom)
}
theirIdentityKeyCopy := []byte(theirIdentityKey)
theirOneTimeKeyCopy := []byte(theirOneTimeKey)
@@ -357,7 +357,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
if len(oneTimeKeyMsg) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
@@ -383,7 +383,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
theirIdentityKeyCopy := []byte(*theirIdentityKey)
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go
index 9ca415ee..6fb5512b 100644
--- a/crypto/libolm/error.go
+++ b/crypto/libolm/error.go
@@ -11,21 +11,21 @@ import (
)
var errorMap = map[string]error{
- "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom,
- "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall,
- "BAD_MESSAGE_VERSION": olm.BadMessageVersion,
- "BAD_MESSAGE_FORMAT": olm.BadMessageFormat,
- "BAD_MESSAGE_MAC": olm.BadMessageMAC,
- "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID,
- "INVALID_BASE64": olm.InvalidBase64,
- "BAD_ACCOUNT_KEY": olm.BadAccountKey,
- "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion,
- "CORRUPTED_PICKLE": olm.CorruptedPickle,
- "BAD_SESSION_KEY": olm.BadSessionKey,
- "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex,
- "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle,
- "BAD_SIGNATURE": olm.BadSignature,
- "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall,
+ "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom,
+ "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall,
+ "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion,
+ "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat,
+ "BAD_MESSAGE_MAC": olm.ErrBadMAC,
+ "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID,
+ "INVALID_BASE64": olm.ErrLibolmInvalidBase64,
+ "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey,
+ "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion,
+ "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle,
+ "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey,
+ "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex,
+ "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle,
+ "BAD_SIGNATURE": olm.ErrBadSignature,
+ "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall,
}
func convertError(errCode string) error {
diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go
index 5606475d..8815ac32 100644
--- a/crypto/libolm/inboundgroupsession.go
+++ b/crypto/libolm/inboundgroupsession.go
@@ -31,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil)
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
lenKey := len(key)
if lenKey == 0 {
@@ -48,7 +48,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession,
// "OLM_BAD_SESSION_KEY".
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_init_inbound_group_session(
@@ -69,7 +69,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
// error will be "OLM_BAD_SESSION_KEY".
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankInboundGroupSession()
r := C.olm_import_inbound_group_session(
@@ -124,7 +124,7 @@ func (s *InboundGroupSession) pickleLen() uint {
// InboundGroupSession using the supplied key.
func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_inbound_group_session(
@@ -143,9 +143,9 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
} else if len(pickled) == 0 {
- return olm.EmptyInput
+ return olm.ErrEmptyInput
}
r := C.olm_unpickle_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
@@ -200,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankInboundGroupSession()
@@ -217,7 +217,7 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
// will be "BAD_MESSAGE_FORMAT".
func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
if len(message) == 0 {
- return 0, olm.EmptyInput
+ return 0, olm.ErrEmptyInput
}
// olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it
messageCopy := bytes.Clone(message)
@@ -244,7 +244,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro
// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
if len(message) == 0 {
- return nil, 0, olm.EmptyInput
+ return nil, 0, olm.ErrEmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message)
if err != nil {
diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go
index 646929eb..ca5b68f7 100644
--- a/crypto/libolm/outboundgroupsession.go
+++ b/crypto/libolm/outboundgroupsession.go
@@ -84,7 +84,7 @@ func (s *OutboundGroupSession) pickleLen() uint {
// OutboundGroupSession using the supplied key.
func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_outbound_group_session(
@@ -103,7 +103,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) {
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_outbound_group_session(
(*C.OlmOutboundGroupSession)(s.int),
@@ -159,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankOutboundGroupSession()
@@ -183,7 +183,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
// as base64.
func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if len(plaintext) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
message := make([]byte, s.encryptMsgLen(len(plaintext)))
r := C.olm_group_encrypt(
diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go
index 35532140..2683cf15 100644
--- a/crypto/libolm/pk.go
+++ b/crypto/libolm/pk.go
@@ -86,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) {
seed := make([]byte, pkSigningSeedLength())
_, err := rand.Read(seed)
if err != nil {
- panic(olm.NotEnoughGoRandom)
+ panic(olm.ErrNotEnoughGoRandom)
}
pk, err := NewPKSigningFromSeed(seed)
return pk, err
diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go
index f091d822..ddf84613 100644
--- a/crypto/libolm/register.go
+++ b/crypto/libolm/register.go
@@ -65,7 +65,7 @@ func Register() {
olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankOutboundGroupSession()
return s, s.Unpickle(pickled, key)
diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go
index 57e631c3..1441df26 100644
--- a/crypto/libolm/session.go
+++ b/crypto/libolm/session.go
@@ -51,7 +51,7 @@ func sessionSize() uint {
// "INVALID_BASE64".
func SessionFromPickled(pickled, key []byte) (*Session, error) {
if len(pickled) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
s := NewBlankSession()
return s, s.Unpickle(pickled, key)
@@ -118,7 +118,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint {
// will be "BAD_MESSAGE_FORMAT".
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
if len(message) == 0 {
- return 0, olm.EmptyInput
+ return 0, olm.ErrEmptyInput
}
messageCopy := []byte(message)
r := C.olm_decrypt_max_plaintext_length(
@@ -138,7 +138,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType)
// supplied key.
func (s *Session) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
- return nil, olm.NoKeyProvided
+ return nil, olm.ErrNoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_session(
@@ -158,7 +158,7 @@ func (s *Session) Pickle(key []byte) ([]byte, error) {
// provided key. This function mutates the input pickled data slice.
func (s *Session) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
- return olm.NoKeyProvided
+ return olm.ErrNoKeyProvided
}
r := C.olm_unpickle_session(
(*C.OlmSession)(s.int),
@@ -213,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
// Deprecated
func (s *Session) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
- return olm.InputNotJSONString
+ return olm.ErrInputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankSession()
@@ -256,7 +256,7 @@ func (s *Session) HasReceivedMessage() bool {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
if len(oneTimeKeyMsg) == 0 {
- return false, olm.EmptyInput
+ return false, olm.ErrEmptyInput
}
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
r := C.olm_matches_inbound_session(
@@ -284,7 +284,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
- return false, olm.EmptyInput
+ return false, olm.ErrEmptyInput
}
theirIdentityKeyCopy := []byte(theirIdentityKey)
oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg)
@@ -325,14 +325,14 @@ func (s *Session) EncryptMsgType() id.OlmMsgType {
// as base64.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
- return 0, nil, olm.EmptyInput
+ return 0, nil, olm.ErrEmptyInput
}
// Make the slice be at least length 1
random := make([]byte, s.encryptRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
// TODO can we just return err here?
- return 0, nil, olm.NotEnoughGoRandom
+ return 0, nil, olm.ErrNotEnoughGoRandom
}
messageType := s.EncryptMsgType()
message := make([]byte, s.encryptMsgLen(len(plaintext)))
@@ -362,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
// "BAD_MESSAGE_MAC".
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
if len(message) == 0 {
- return nil, olm.EmptyInput
+ return nil, olm.ErrEmptyInput
}
decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType)
if err != nil {
diff --git a/crypto/machine.go b/crypto/machine.go
index 4d2e3880..fa051f94 100644
--- a/crypto/machine.go
+++ b/crypto/machine.go
@@ -39,6 +39,7 @@ type OlmMachine struct {
cancelBackgroundCtx context.CancelFunc
PlaintextMentions bool
+ MSC4392Relations bool
AllowEncryptedState bool
// Never ask the server for keys automatically as a side effect during Megolm decryption.
@@ -205,7 +206,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error {
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
start := time.Now()
return func() {
- duration := time.Now().Sub(start)
+ duration := time.Since(start)
if duration > expectedDuration {
zerolog.Ctx(ctx).Warn().
Str("action", thing).
diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go
index 957d7928..9e522b2a 100644
--- a/crypto/olm/errors.go
+++ b/crypto/olm/errors.go
@@ -10,50 +10,67 @@ import "errors"
// Those are the most common used errors
var (
- ErrBadSignature = errors.New("bad signature")
- ErrBadMAC = errors.New("bad mac")
- ErrBadMessageFormat = errors.New("bad message format")
- ErrBadVerification = errors.New("bad verification")
- ErrWrongProtocolVersion = errors.New("wrong protocol version")
- ErrEmptyInput = errors.New("empty input")
- ErrNoKeyProvided = errors.New("no key")
- ErrBadMessageKeyID = errors.New("bad message key id")
- ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
- ErrMsgIndexTooHigh = errors.New("message index too high")
- ErrProtocolViolation = errors.New("not protocol message order")
- ErrMessageKeyNotFound = errors.New("message key not found")
- ErrChainTooHigh = errors.New("chain index too high")
- ErrBadInput = errors.New("bad input")
- ErrBadVersion = errors.New("wrong version")
- ErrWrongPickleVersion = errors.New("wrong pickle version")
- ErrInputToSmall = errors.New("input too small (truncated?)")
- ErrOverflow = errors.New("overflow")
+ ErrBadSignature = errors.New("bad signature")
+ ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)")
+ ErrBadMessageFormat = errors.New("the message couldn't be decoded")
+ ErrBadVerification = errors.New("bad verification")
+ ErrWrongProtocolVersion = errors.New("wrong protocol version")
+ ErrEmptyInput = errors.New("empty input")
+ ErrNoKeyProvided = errors.New("no key provided")
+ ErrBadMessageKeyID = errors.New("the message references an unknown key ID")
+ ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
+ ErrMsgIndexTooHigh = errors.New("message index too high")
+ ErrProtocolViolation = errors.New("not protocol message order")
+ ErrMessageKeyNotFound = errors.New("message key not found")
+ ErrChainTooHigh = errors.New("chain index too high")
+ ErrBadInput = errors.New("bad input")
+ ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version")
+ ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version")
+ ErrInputToSmall = errors.New("input too small (truncated?)")
)
// Error codes from go-olm
var (
- EmptyInput = errors.New("empty input")
- NoKeyProvided = errors.New("no pickle key provided")
- NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
- SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
- InputNotJSONString = errors.New("input doesn't look like a JSON string")
+ ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
+ ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
)
// Error codes from olm code
var (
- NotEnoughRandom = errors.New("not enough entropy was supplied")
- OutputBufferTooSmall = errors.New("supplied output buffer is too small")
- BadMessageVersion = errors.New("the message version is unsupported")
- BadMessageFormat = errors.New("the message couldn't be decoded")
- BadMessageMAC = errors.New("the message couldn't be decrypted")
- BadMessageKeyID = errors.New("the message references an unknown key ID")
- InvalidBase64 = errors.New("the input base64 was invalid")
- BadAccountKey = errors.New("the supplied account key is invalid")
- UnknownPickleVersion = errors.New("the pickled object is too new")
- CorruptedPickle = errors.New("the pickled object couldn't be decoded")
- BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
- UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key")
- BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
- BadSignature = errors.New("received message had a bad signature")
- InputBufferTooSmall = errors.New("the input data was too small to be valid")
+ ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid")
+
+ ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied")
+ ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small")
+ ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid")
+ ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded")
+ ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key")
+ ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ EmptyInput = ErrEmptyInput
+ BadSignature = ErrBadSignature
+ InvalidBase64 = ErrLibolmInvalidBase64
+ BadMessageKeyID = ErrBadMessageKeyID
+ BadMessageFormat = ErrBadMessageFormat
+ BadMessageVersion = ErrWrongProtocolVersion
+ BadMessageMAC = ErrBadMAC
+ UnknownPickleVersion = ErrUnknownOlmPickleVersion
+ NotEnoughRandom = ErrLibolmNotEnoughRandom
+ OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall
+ BadAccountKey = ErrLibolmBadAccountKey
+ CorruptedPickle = ErrLibolmCorruptedPickle
+ BadSessionKey = ErrLibolmBadSessionKey
+ UnknownMessageIndex = ErrUnknownMessageIndex
+ BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle
+ InputBufferTooSmall = ErrInputToSmall
+ NoKeyProvided = ErrNoKeyProvided
+
+ NotEnoughGoRandom = ErrNotEnoughGoRandom
+ InputNotJSONString = ErrInputNotJSONString
+
+ ErrBadVersion = ErrUnknownJSONPickleVersion
+ ErrWrongPickleVersion = ErrUnknownJSONPickleVersion
+ ErrRatchetNotAvailable = ErrUnknownMessageIndex
)
diff --git a/crypto/sessions.go b/crypto/sessions.go
index aecb0416..ccc7b784 100644
--- a/crypto/sessions.go
+++ b/crypto/sessions.go
@@ -18,8 +18,14 @@ import (
)
var (
- SessionNotShared = errors.New("session has not been shared")
- SessionExpired = errors.New("session has expired")
+ ErrSessionNotShared = errors.New("session has not been shared")
+ ErrSessionExpired = errors.New("session has expired")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ SessionNotShared = ErrSessionNotShared
+ SessionExpired = ErrSessionExpired
)
// OlmSessionList is a list of OlmSessions.
@@ -111,6 +117,7 @@ type InboundGroupSession struct {
MaxMessages int
IsScheduled bool
KeyBackupVersion id.KeyBackupVersion
+ KeySource id.KeySource
id id.SessionID
}
@@ -130,6 +137,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: isScheduled,
+ KeySource: id.KeySourceDirect,
}, nil
}
@@ -163,7 +171,7 @@ func (igs *InboundGroupSession) export() (*ExportedSession, error) {
ForwardingChains: igs.ForwardingChains,
RoomID: igs.RoomID,
SenderKey: igs.SenderKey,
- SenderClaimedKeys: SenderClaimedKeys{},
+ SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey},
SessionID: igs.ID(),
SessionKey: string(key),
}, nil
@@ -255,9 +263,9 @@ func (ogs *OutboundGroupSession) Expired() bool {
func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if !ogs.Shared {
- return nil, SessionNotShared
+ return nil, ErrSessionNotShared
} else if ogs.Expired() {
- return nil, SessionExpired
+ return nil, ErrSessionExpired
}
ogs.MessageCount++
ogs.LastEncryptedTime = time.Now()
diff --git a/crypto/sql_store.go b/crypto/sql_store.go
index ca75b3f6..138cc557 100644
--- a/crypto/sql_store.go
+++ b/crypto/sql_store.go
@@ -346,22 +346,23 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou
Int("max_messages", session.MaxMessages).
Bool("is_scheduled", session.IsScheduled).
Stringer("key_backup_version", session.KeyBackupVersion).
+ Stringer("key_source", session.KeySource).
Msg("Upserting megolm inbound group session")
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
- ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id
- ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+ ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
- key_backup_version=excluded.key_backup_version
+ key_backup_version=excluded.key_backup_version, key_source=excluded.key_source
`,
session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages),
- session.IsScheduled, session.KeyBackupVersion, store.AccountID,
+ session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID,
)
return err
}
@@ -374,12 +375,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
var version id.KeyBackupVersion
+ var keySource id.KeySource
err := store.DB.QueryRow(ctx, `
- SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND session_id=$2 AND account_id=$3`,
roomID, sessionID, store.AccountID,
- ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
+ ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
@@ -410,6 +412,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
+ KeySource: keySource,
}, nil
}
@@ -534,7 +537,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
var version id.KeyBackupVersion
- err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
+ var keySource id.KeySource
+ err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource)
if err != nil {
return nil, err
}
@@ -554,12 +558,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
KeyBackupVersion: version,
+ KeySource: keySource,
}, nil
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
@@ -568,7 +573,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
store.AccountID,
)
@@ -577,7 +582,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
- SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
+ SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
store.AccountID, version,
)
diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql
index af8ab5cc..3709f1e5 100644
--- a/crypto/sql_store_upgrade/00-latest-revision.sql
+++ b/crypto/sql_store_upgrade/00-latest-revision.sql
@@ -1,4 +1,4 @@
--- v0 -> v18 (compatible with v15+): Latest revision
+-- v0 -> v19 (compatible with v15+): Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@@ -71,6 +71,7 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
key_backup_version TEXT NOT NULL DEFAULT '',
+ key_source TEXT NOT NULL DEFAULT '',
PRIMARY KEY (account_id, session_id)
);
-- Useful index to find keys that need backing up
diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql
new file mode 100644
index 00000000..f624222f
--- /dev/null
+++ b/crypto/sql_store_upgrade/19-megolm-session-source.sql
@@ -0,0 +1,2 @@
+-- v19 (compatible with v15+): Store megolm session source
+ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT '';
diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go
index e30925d9..8691d032 100644
--- a/crypto/ssss/client.go
+++ b/crypto/ssss/client.go
@@ -95,6 +95,22 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even
return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted})
}
+// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it,
+// alongside the unencrypted metadata, on the server.
+func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error {
+ if len(keys) == 0 {
+ return ErrNoKeyGiven
+ }
+ encrypted := make(map[string]EncryptedKeyData, len(keys))
+ for _, key := range keys {
+ encrypted[key.ID] = key.Encrypt(eventType.Type, data)
+ }
+ return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{
+ Encrypted: encrypted,
+ Metadata: metadata,
+ })
+}
+
// GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server.
func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) {
key, err = NewKey(passphrase)
diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go
index cd8e3fce..78ebd8f3 100644
--- a/crypto/ssss/key.go
+++ b/crypto/ssss/key.go
@@ -59,12 +59,12 @@ func NewKey(passphrase string) (*Key, error) {
// We store a certain hash in the key metadata so that clients can check if the user entered the correct key.
ivBytes := random.Bytes(utils.AESCTRIVLength)
keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes)
- var err error
- keyData.MAC, err = keyData.calculateHash(ssssKey)
+ macBytes, err := keyData.calculateHash(ssssKey)
if err != nil {
// This should never happen because we just generated the IV and key.
return nil, fmt.Errorf("failed to calculate hash: %w", err)
}
+ keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes)
return &Key{
Key: ssssKey,
diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go
index 474c85d8..34775fa7 100644
--- a/crypto/ssss/meta.go
+++ b/crypto/ssss/meta.go
@@ -7,7 +7,10 @@
package ssss
import (
+ "crypto/hmac"
+ "crypto/sha256"
"encoding/base64"
+ "errors"
"fmt"
"strings"
@@ -33,7 +36,9 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error)
ssssKey, err := kd.Passphrase.GetKey(passphrase)
if err != nil {
return nil, err
- } else if err = kd.verifyKey(ssssKey); err != nil {
+ }
+ err = kd.verifyKey(ssssKey)
+ if err != nil && !errors.Is(err, ErrUnverifiableKey) {
return nil, err
}
@@ -49,7 +54,9 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error
ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey)
if ssssKey == nil {
return nil, ErrInvalidRecoveryKey
- } else if err := kd.verifyKey(ssssKey); err != nil {
+ }
+ err := kd.verifyKey(ssssKey)
+ if err != nil && !errors.Is(err, ErrUnverifiableKey) {
return nil, err
}
@@ -57,20 +64,28 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error
ID: keyID,
Key: ssssKey,
Metadata: kd,
- }, nil
+ }, err
}
func (kd *KeyMetadata) verifyKey(key []byte) error {
+ if kd.MAC == "" || kd.IV == "" {
+ return ErrUnverifiableKey
+ }
unpaddedMAC := strings.TrimRight(kd.MAC, "=")
expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength)
if len(unpaddedMAC) != expectedMACLength {
return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength)
}
- hash, err := kd.calculateHash(key)
+ expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC)
+ if err != nil {
+ return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err)
+ }
+ calculatedMAC, err := kd.calculateHash(key)
if err != nil {
return err
}
- if unpaddedMAC != hash {
+ // This doesn't really need to be constant time since it's fully local, but might as well be.
+ if !hmac.Equal(expectedMAC, calculatedMAC) {
return ErrIncorrectSSSSKey
}
return nil
@@ -83,23 +98,26 @@ func (kd *KeyMetadata) VerifyKey(key []byte) bool {
// calculateHash calculates the hash used for checking if the key is entered correctly as described
// in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2
-func (kd *KeyMetadata) calculateHash(key []byte) (string, error) {
+func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) {
aesKey, hmacKey := utils.DeriveKeysSHA256(key, "")
unpaddedIV := strings.TrimRight(kd.IV, "=")
expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength)
- if len(unpaddedIV) != expectedIVLength {
- return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength)
+ if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 {
+ return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength)
}
-
- var ivBytes [utils.AESCTRIVLength]byte
- _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV))
+ rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV)
if err != nil {
- return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err)
+ return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err)
}
+ // TODO log a warning for non-16 byte IVs?
+ // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used.
+ ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength])
- cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes)
-
- return utils.HMACSHA256B64(cipher, hmacKey), nil
+ zeroes := make([]byte, utils.AESCTRKeyLength)
+ encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes)
+ h := hmac.New(sha256.New, hmacKey[:])
+ h.Write(encryptedZeroes)
+ return h.Sum(nil), nil
}
// PassphraseMetadata represents server-side metadata about a SSSS key passphrase.
diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go
index 7a5ef8b9..d59809c7 100644
--- a/crypto/ssss/meta_test.go
+++ b/crypto/ssss/meta_test.go
@@ -8,7 +8,6 @@ package ssss_test
import (
"encoding/json"
- "errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -42,10 +41,24 @@ const key2Meta = `
}
`
+const key2MetaUnverified = `
+{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2"
+}
+`
+
+const key2MetaLongIV = `
+{
+ "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+ "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=",
+ "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
+}
+`
+
const key2MetaBrokenIV = `
{
"algorithm": "m.secret_storage.v1.aes-hmac-sha2",
- "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow",
+ "iv": "MeowMeowMeow",
"mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
}
`
@@ -94,17 +107,33 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) {
assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
}
+func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) {
+ km := getKeyMeta(key2MetaLongIV)
+ key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
+ assert.NoError(t, err)
+ assert.NotNil(t, key)
+ assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
+}
+
+func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) {
+ km := getKeyMeta(key2MetaUnverified)
+ key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
+ assert.ErrorIs(t, err, ssss.ErrUnverifiableKey)
+ assert.NotNil(t, key)
+ assert.Equal(t, key2RecoveryKey, key.RecoveryKey())
+}
+
func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyRecoveryKey(key1ID, "foo")
- assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err)
+ assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err)
+ assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey)
assert.Nil(t, key)
}
@@ -119,27 +148,27 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) {
func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) {
km := getKeyMeta(key1Meta)
key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple")
- assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) {
km := getKeyMeta(key2Meta)
key, err := km.VerifyPassphrase(key2ID, "hmm")
- assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrNoPassphrase)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) {
km := getKeyMeta(key2MetaBrokenIV)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata)
assert.Nil(t, key)
}
func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) {
km := getKeyMeta(key2MetaBrokenMAC)
key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey)
- assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err)
+ assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata)
assert.Nil(t, key)
}
diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go
index 345393b0..b7465d3e 100644
--- a/crypto/ssss/types.go
+++ b/crypto/ssss/types.go
@@ -26,7 +26,8 @@ var (
ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm")
ErrIncorrectSSSSKey = errors.New("incorrect SSSS key")
ErrInvalidRecoveryKey = errors.New("invalid recovery key")
- ErrCorruptedKeyMetadata = errors.New("corrupted key metadata")
+ ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata")
+ ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata")
)
// Algorithm is the identifier for an SSSS encryption algorithm.
@@ -57,6 +58,7 @@ type EncryptedKeyData struct {
type EncryptedAccountDataEventContent struct {
Encrypted map[string]EncryptedKeyData `json:"encrypted"`
+ Metadata map[string]any `json:"com.beeper.metadata,omitzero"`
}
func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) {
diff --git a/error.go b/error.go
index 826af179..4711b3dc 100644
--- a/error.go
+++ b/error.go
@@ -85,6 +85,10 @@ var (
ErrResponseTooLong = errors.New("response content length too long")
ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body")
+
+ // Special error that indicates we should retry canceled contexts. Note that on it's own this
+ // is useless, the context itself must also be replaced.
+ ErrContextCancelRetry = errors.New("retry canceled context")
)
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.
@@ -136,7 +140,10 @@ type RespError struct {
Err string
ExtraData map[string]any
- StatusCode int
+ StatusCode int
+ ExtraHeader map[string]string
+
+ CanRetry bool
}
func (e *RespError) UnmarshalJSON(data []byte) error {
@@ -146,6 +153,7 @@ func (e *RespError) UnmarshalJSON(data []byte) error {
}
e.ErrCode, _ = e.ExtraData["errcode"].(string)
e.Err, _ = e.ExtraData["error"].(string)
+ e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool)
return nil
}
@@ -153,6 +161,9 @@ func (e *RespError) MarshalJSON() ([]byte, error) {
data := exmaps.NonNilClone(e.ExtraData)
data["errcode"] = e.ErrCode
data["error"] = e.Err
+ if e.CanRetry {
+ data["com.beeper.can_retry"] = e.CanRetry
+ }
return json.Marshal(data)
}
@@ -164,6 +175,9 @@ func (e RespError) Write(w http.ResponseWriter) {
if statusCode == 0 {
statusCode = http.StatusInternalServerError
}
+ for key, value := range e.ExtraHeader {
+ w.Header().Set(key, value)
+ }
exhttp.WriteJSONResponse(w, statusCode, &e)
}
@@ -180,12 +194,29 @@ func (e RespError) WithStatus(status int) RespError {
return e
}
+func (e RespError) WithCanRetry(canRetry bool) RespError {
+ e.CanRetry = canRetry
+ return e
+}
+
func (e RespError) WithExtraData(extraData map[string]any) RespError {
e.ExtraData = exmaps.NonNilClone(e.ExtraData)
maps.Copy(e.ExtraData, extraData)
return e
}
+func (e RespError) WithExtraHeader(key, value string) RespError {
+ e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader)
+ e.ExtraHeader[key] = value
+ return e
+}
+
+func (e RespError) WithExtraHeaders(headers map[string]string) RespError {
+ e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader)
+ maps.Copy(e.ExtraHeader, headers)
+ return e
+}
+
// Error returns the errcode and error message.
func (e RespError) Error() string {
return e.ErrCode + ": " + e.Err
diff --git a/event/beeper.go b/event/beeper.go
index 95b4a571..a1a60b35 100644
--- a/event/beeper.go
+++ b/event/beeper.go
@@ -53,6 +53,8 @@ type BeeperMessageStatusEventContent struct {
LastRetry id.EventID `json:"last_retry,omitempty"`
+ TargetTxnID string `json:"relates_to_txn_id,omitempty"`
+
MutateEventKey string `json:"mutate_event_key,omitempty"`
// Indicates the set of users to whom the event was delivered. If nil, then
@@ -87,7 +89,19 @@ type BeeperRoomKeyAckEventContent struct {
}
type BeeperChatDeleteEventContent struct {
- DeleteForEveryone bool `json:"delete_for_everyone,omitempty"`
+ DeleteForEveryone bool `json:"delete_for_everyone,omitempty"`
+ FromMessageRequest bool `json:"from_message_request,omitempty"`
+}
+
+type BeeperAcceptMessageRequestEventContent struct {
+ // Whether this was triggered by a message rather than an explicit event
+ IsImplicit bool `json:"-"`
+}
+
+type BeeperSendStateEventContent struct {
+ Type string `json:"type"`
+ StateKey string `json:"state_key"`
+ Content Content `json:"content"`
}
type IntOrString int
@@ -132,6 +146,7 @@ type BeeperLinkPreview struct {
MatchedURL string `json:"matched_url,omitempty"`
ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"`
+ ImageBlurhash string `json:"matrix:image:blurhash,omitempty"`
}
type BeeperProfileExtra struct {
@@ -151,6 +166,24 @@ type BeeperPerMessageProfile struct {
HasFallback bool `json:"has_fallback,omitempty"`
}
+type BeeperActionMessageType string
+
+const (
+ BeeperActionMessageCall BeeperActionMessageType = "call"
+)
+
+type BeeperActionMessageCallType string
+
+const (
+ BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice"
+ BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video"
+)
+
+type BeeperActionMessage struct {
+ Type BeeperActionMessageType `json:"type"`
+ CallType BeeperActionMessageCallType `json:"call_type,omitempty"`
+}
+
func (content *MessageEventContent) AddPerMessageProfileFallback() {
if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" {
return
@@ -181,6 +214,15 @@ func (content *MessageEventContent) RemovePerMessageProfileFallback() {
}
}
+type BeeperAIStreamEventContent struct {
+ TurnID string `json:"turn_id"`
+ Seq int `json:"seq"`
+ Part map[string]any `json:"part"`
+ TargetEvent id.EventID `json:"target_event,omitempty"`
+ AgentID string `json:"agent_id,omitempty"`
+ RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
+}
+
type BeeperEncodedOrder struct {
order int64
suborder int16
diff --git a/event/botcommand.go b/event/botcommand.go
deleted file mode 100644
index 2b208656..00000000
--- a/event/botcommand.go
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright (c) 2025 Tulir Asokan
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this
-// file, You can obtain one at http://mozilla.org/MPL/2.0/.
-
-package event
-
-import (
- "encoding/json"
-)
-
-type BotCommandsEventContent struct {
- Sigil string `json:"sigil,omitempty"`
- Commands []*BotCommand `json:"commands,omitempty"`
-}
-
-type BotCommand struct {
- Syntax string `json:"syntax"`
- Aliases []string `json:"fi.mau.aliases,omitempty"` // Not in MSC (yet)
- Arguments []*BotCommandArgument `json:"arguments,omitempty"`
- Description *ExtensibleTextContainer `json:"description,omitempty"`
-}
-
-type BotArgumentType string
-
-const (
- BotArgumentTypeString BotArgumentType = "string"
- BotArgumentTypeEnum BotArgumentType = "enum"
- BotArgumentTypeInteger BotArgumentType = "integer"
- BotArgumentTypeBoolean BotArgumentType = "boolean"
- BotArgumentTypeUserID BotArgumentType = "user_id"
- BotArgumentTypeRoomID BotArgumentType = "room_id"
- BotArgumentTypeRoomAlias BotArgumentType = "room_alias"
- BotArgumentTypeEventID BotArgumentType = "event_id"
-)
-
-type BotCommandArgument struct {
- Type BotArgumentType `json:"type"`
- DefaultValue any `json:"fi.mau.default_value,omitempty"` // Not in MSC (yet)
- Description *ExtensibleTextContainer `json:"description,omitempty"`
- Enum []string `json:"enum,omitempty"`
- Variadic bool `json:"variadic,omitempty"`
-}
-
-type BotCommandInput struct {
- Syntax string `json:"syntax"`
- Arguments json.RawMessage `json:"arguments,omitempty"`
-}
diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts
index 1fbc9610..26aeb347 100644
--- a/event/capabilities.d.ts
+++ b/event/capabilities.d.ts
@@ -77,6 +77,11 @@ export interface RoomFeatures {
delete_chat?: boolean
/** Whether deleting the chat for all participants is supported. */
delete_chat_for_everyone?: boolean
+ /** What can be done with message requests? */
+ message_request?: {
+ accept_with_message?: CapabilitySupportLevel
+ accept_with_button?: CapabilitySupportLevel
+ }
}
declare type integer = number
diff --git a/event/capabilities.go b/event/capabilities.go
index 4b7ff186..a86c726b 100644
--- a/event/capabilities.go
+++ b/event/capabilities.go
@@ -61,6 +61,8 @@ type RoomFeatures struct {
DeleteChat bool `json:"delete_chat,omitempty"`
DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"`
+ MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"`
+
PerMessageProfileRelay bool `json:"-"`
}
@@ -84,6 +86,7 @@ func (rf *RoomFeatures) Clone() *RoomFeatures {
clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge)
clone.DisappearingTimer = clone.DisappearingTimer.Clone()
clone.AllowedReactions = slices.Clone(clone.AllowedReactions)
+ clone.MessageRequest = clone.MessageRequest.Clone()
return &clone
}
@@ -165,6 +168,25 @@ func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTime
return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer))
}
+type MessageRequestFeatures struct {
+ AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"`
+ AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"`
+}
+
+func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures {
+ return ptr.Clone(mrf)
+}
+
+func (mrf *MessageRequestFeatures) Hash() []byte {
+ if mrf == nil {
+ return nil
+ }
+ hasher := sha256.New()
+ hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage)
+ hashValue(hasher, "accept_with_button", mrf.AcceptWithButton)
+ return hasher.Sum(nil)
+}
+
type CapabilityMsgType = MessageType
// Message types which are used for event capability signaling, but aren't real values for the msgtype field.
@@ -347,6 +369,7 @@ func (rf *RoomFeatures) Hash() []byte {
hashBool(hasher, "mark_as_unread", rf.MarkAsUnread)
hashBool(hasher, "delete_chat", rf.DeleteChat)
hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone)
+ hashValue(hasher, "message_request", rf.MessageRequest)
return hasher.Sum(nil)
}
diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go
new file mode 100644
index 00000000..ce07c4c0
--- /dev/null
+++ b/event/cmdschema/content.go
@@ -0,0 +1,78 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "reflect"
+ "slices"
+
+ "go.mau.fi/util/exsync"
+ "go.mau.fi/util/ptr"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type EventContent struct {
+ Command string `json:"command"`
+ Aliases []string `json:"aliases,omitempty"`
+ Parameters []*Parameter `json:"parameters,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ TailParam string `json:"fi.mau.tail_parameter,omitempty"`
+}
+
+func (ec *EventContent) Validate() error {
+ if ec == nil {
+ return fmt.Errorf("event content is nil")
+ } else if ec.Command == "" {
+ return fmt.Errorf("command is empty")
+ }
+ var tailFound bool
+ dupMap := exsync.NewSet[string]()
+ for i, p := range ec.Parameters {
+ if err := p.Validate(); err != nil {
+ return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err)
+ } else if !dupMap.Add(p.Key) {
+ return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1)
+ } else if p.Key == ec.TailParam {
+ tailFound = true
+ } else if tailFound && !p.Optional {
+ return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam)
+ }
+ }
+ if ec.TailParam != "" && !tailFound {
+ return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam)
+ }
+ return nil
+}
+
+func (ec *EventContent) IsValid() bool {
+ return ec.Validate() == nil
+}
+
+func (ec *EventContent) StateKey(owner id.UserID) string {
+ hash := sha256.Sum256([]byte(ec.Command + owner.String()))
+ return base64.StdEncoding.EncodeToString(hash[:])
+}
+
+func (ec *EventContent) Equals(other *EventContent) bool {
+ if ec == nil || other == nil {
+ return ec == other
+ }
+ return ec.Command == other.Command &&
+ slices.Equal(ec.Aliases, other.Aliases) &&
+ slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) &&
+ ec.Description.Equals(other.Description) &&
+ ec.TailParam == other.TailParam
+}
+
+func init() {
+ event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{})
+}
diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go
new file mode 100644
index 00000000..4193b297
--- /dev/null
+++ b/event/cmdschema/parameter.go
@@ -0,0 +1,286 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+
+ "go.mau.fi/util/exslices"
+
+ "maunium.net/go/mautrix/event"
+)
+
+type Parameter struct {
+ Key string `json:"key"`
+ Schema *ParameterSchema `json:"schema"`
+ Optional bool `json:"optional,omitempty"`
+ Description *event.ExtensibleTextContainer `json:"description,omitempty"`
+ DefaultValue any `json:"fi.mau.default_value,omitempty"`
+}
+
+func (p *Parameter) Equals(other *Parameter) bool {
+ if p == nil || other == nil {
+ return p == other
+ }
+ return p.Key == other.Key &&
+ p.Schema.Equals(other.Schema) &&
+ p.Optional == other.Optional &&
+ p.Description.Equals(other.Description) &&
+ p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values
+}
+
+func (p *Parameter) Validate() error {
+ if p == nil {
+ return fmt.Errorf("parameter is nil")
+ } else if p.Key == "" {
+ return fmt.Errorf("key is empty")
+ }
+ return p.Schema.Validate()
+}
+
+func (p *Parameter) IsValid() bool {
+ return p.Validate() == nil
+}
+
+func (p *Parameter) GetDefaultValue() any {
+ if p != nil && p.DefaultValue != nil {
+ return p.DefaultValue
+ } else if p == nil || p.Optional {
+ return nil
+ }
+ return p.Schema.GetDefaultValue()
+}
+
+type PrimitiveType string
+
+const (
+ PrimitiveTypeString PrimitiveType = "string"
+ PrimitiveTypeInteger PrimitiveType = "integer"
+ PrimitiveTypeBoolean PrimitiveType = "boolean"
+ PrimitiveTypeServerName PrimitiveType = "server_name"
+ PrimitiveTypeUserID PrimitiveType = "user_id"
+ PrimitiveTypeRoomID PrimitiveType = "room_id"
+ PrimitiveTypeRoomAlias PrimitiveType = "room_alias"
+ PrimitiveTypeEventID PrimitiveType = "event_id"
+)
+
+func (pt PrimitiveType) Schema() *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypePrimitive,
+ Type: pt,
+ }
+}
+
+func (pt PrimitiveType) IsValid() bool {
+ switch pt {
+ case PrimitiveTypeString,
+ PrimitiveTypeInteger,
+ PrimitiveTypeBoolean,
+ PrimitiveTypeServerName,
+ PrimitiveTypeUserID,
+ PrimitiveTypeRoomID,
+ PrimitiveTypeRoomAlias,
+ PrimitiveTypeEventID:
+ return true
+ default:
+ return false
+ }
+}
+
+type SchemaType string
+
+const (
+ SchemaTypePrimitive SchemaType = "primitive"
+ SchemaTypeArray SchemaType = "array"
+ SchemaTypeUnion SchemaType = "union"
+ SchemaTypeLiteral SchemaType = "literal"
+)
+
+type ParameterSchema struct {
+ SchemaType SchemaType `json:"schema_type"`
+ Type PrimitiveType `json:"type,omitempty"` // Only for primitive
+ Items *ParameterSchema `json:"items,omitempty"` // Only for array
+ Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union
+ Value any `json:"value,omitempty"` // Only for literal
+}
+
+func Literal(value any) *ParameterSchema {
+ return &ParameterSchema{
+ SchemaType: SchemaTypeLiteral,
+ Value: value,
+ }
+}
+
+func Enum(values ...any) *ParameterSchema {
+ return Union(exslices.CastFunc(values, Literal)...)
+}
+
+func flattenUnion(variants []*ParameterSchema) []*ParameterSchema {
+ var flattened []*ParameterSchema
+ for _, variant := range variants {
+ switch variant.SchemaType {
+ case SchemaTypeArray:
+ panic(fmt.Errorf("illegal array schema in union"))
+ case SchemaTypeUnion:
+ flattened = append(flattened, flattenUnion(variant.Variants)...)
+ default:
+ flattened = append(flattened, variant)
+ }
+ }
+ return flattened
+}
+
+func Union(variants ...*ParameterSchema) *ParameterSchema {
+ needsFlattening := false
+ for _, variant := range variants {
+ if variant.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in union"))
+ } else if variant.SchemaType == SchemaTypeUnion {
+ needsFlattening = true
+ }
+ }
+ if needsFlattening {
+ variants = flattenUnion(variants)
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeUnion,
+ Variants: variants,
+ }
+}
+
+func Array(items *ParameterSchema) *ParameterSchema {
+ if items.SchemaType == SchemaTypeArray {
+ panic(fmt.Errorf("illegal array schema in array"))
+ }
+ return &ParameterSchema{
+ SchemaType: SchemaTypeArray,
+ Items: items,
+ }
+}
+
+func (ps *ParameterSchema) GetDefaultValue() any {
+ if ps == nil {
+ return nil
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ switch ps.Type {
+ case PrimitiveTypeInteger:
+ return 0
+ case PrimitiveTypeBoolean:
+ return false
+ default:
+ return ""
+ }
+ case SchemaTypeArray:
+ return []any{}
+ case SchemaTypeUnion:
+ if len(ps.Variants) > 0 {
+ return ps.Variants[0].GetDefaultValue()
+ }
+ return nil
+ case SchemaTypeLiteral:
+ return ps.Value
+ default:
+ return nil
+ }
+}
+
+func (ps *ParameterSchema) IsValid() bool {
+ return ps.validate("") == nil
+}
+
+func (ps *ParameterSchema) Validate() error {
+ return ps.validate("")
+}
+
+func (ps *ParameterSchema) validate(parent SchemaType) error {
+ if ps == nil {
+ return fmt.Errorf("schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ if !ps.Type.IsValid() {
+ return fmt.Errorf("invalid primitive type %s", ps.Type)
+ } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("primitive schema has extra fields")
+ }
+ return nil
+ case SchemaTypeArray:
+ if parent != "" {
+ return fmt.Errorf("arrays can't be nested in other types")
+ } else if err := ps.Items.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("item schema is invalid: %w", err)
+ } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil {
+ return fmt.Errorf("array schema has extra fields")
+ }
+ return nil
+ case SchemaTypeUnion:
+ if len(ps.Variants) == 0 {
+ return fmt.Errorf("no variants specified for union")
+ } else if parent != "" && parent != SchemaTypeArray {
+ return fmt.Errorf("unions can't be nested in anything other than arrays")
+ }
+ for i, v := range ps.Variants {
+ if err := v.validate(ps.SchemaType); err != nil {
+ return fmt.Errorf("variant #%d is invalid: %w", i+1, err)
+ }
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Value != nil {
+ return fmt.Errorf("union schema has extra fields")
+ }
+ return nil
+ case SchemaTypeLiteral:
+ switch typedVal := ps.Value.(type) {
+ case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue:
+ // ok
+ case map[string]any:
+ if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" {
+ return fmt.Errorf("literal value has invalid map data")
+ }
+ default:
+ return fmt.Errorf("literal value has unsupported type %T", ps.Value)
+ }
+ if ps.Type != "" || ps.Items != nil || ps.Variants != nil {
+ return fmt.Errorf("literal schema has extra fields")
+ }
+ return nil
+ default:
+ return fmt.Errorf("invalid schema type %s", ps.SchemaType)
+ }
+}
+
+func (ps *ParameterSchema) Equals(other *ParameterSchema) bool {
+ if ps == nil || other == nil {
+ return ps == other
+ }
+ return ps.SchemaType == other.SchemaType &&
+ ps.Type == other.Type &&
+ ps.Items.Equals(other.Items) &&
+ slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) &&
+ ps.Value == other.Value // TODO this won't work for room/event ID values
+}
+
+func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool {
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type == prim
+ case SchemaTypeUnion:
+ for _, variant := range ps.Variants {
+ if variant.AllowsPrimitive(prim) {
+ return true
+ }
+ }
+ return false
+ case SchemaTypeArray:
+ return ps.Items.AllowsPrimitive(prim)
+ default:
+ return false
+ }
+}
diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go
new file mode 100644
index 00000000..92e69b60
--- /dev/null
+++ b/event/cmdschema/parse.go
@@ -0,0 +1,478 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+const botArrayOpener = "<"
+const botArrayCloser = ">"
+
+func parseQuoted(val string) (parsed, remaining string, quoted bool) {
+ if len(val) == 0 {
+ return
+ }
+ if !strings.HasPrefix(val, `"`) {
+ spaceIdx := strings.IndexByte(val, ' ')
+ if spaceIdx == -1 {
+ parsed = val
+ } else {
+ parsed = val[:spaceIdx]
+ remaining = strings.TrimLeft(val[spaceIdx+1:], " ")
+ }
+ return
+ }
+ val = val[1:]
+ var buf strings.Builder
+ for {
+ quoteIdx := strings.IndexByte(val, '"')
+ var valUntilQuote string
+ if quoteIdx == -1 {
+ valUntilQuote = val
+ } else {
+ valUntilQuote = val[:quoteIdx]
+ }
+ escapeIdx := strings.IndexByte(valUntilQuote, '\\')
+ if escapeIdx >= 0 {
+ buf.WriteString(val[:escapeIdx])
+ if len(val) > escapeIdx+1 {
+ buf.WriteByte(val[escapeIdx+1])
+ }
+ val = val[min(escapeIdx+2, len(val)):]
+ } else if quoteIdx >= 0 {
+ buf.WriteString(val[:quoteIdx])
+ val = val[quoteIdx+1:]
+ break
+ } else if buf.Len() == 0 {
+ // Unterminated quote, no escape characters, val is the whole input
+ return val, "", true
+ } else {
+ // Unterminated quote, but there were escape characters previously
+ buf.WriteString(val)
+ val = ""
+ break
+ }
+ }
+ return buf.String(), strings.TrimLeft(val, " "), true
+}
+
+// ParseInput tries to parse the given text into a bot command event matching this command definition.
+//
+// If the prefix doesn't match, this will return a nil content and nil error.
+// If the prefix does match, some content is always returned, but there may still be an error if parsing failed.
+func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) {
+ prefix := ec.parsePrefix(input, sigils, owner.String())
+ if prefix == "" {
+ return nil, nil
+ }
+ content = &event.MessageEventContent{
+ MsgType: event.MsgText,
+ Body: input,
+ Mentions: &event.Mentions{UserIDs: []id.UserID{owner}},
+ MSC4391BotCommand: &event.MSC4391BotCommandInput{
+ Command: ec.Command,
+ },
+ }
+ content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):])
+ return content, err
+}
+
+func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) {
+ args := make(map[string]any)
+ var retErr error
+ setError := func(err error) {
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+ processParameter := func(param *Parameter, isLast, isTail, isNamed bool) {
+ origInput := input
+ var nextVal string
+ var wasQuoted bool
+ if param.Schema.SchemaType == SchemaTypeArray {
+ hasOpener := strings.HasPrefix(input, botArrayOpener)
+ arrayClosed := false
+ if hasOpener {
+ input = input[len(botArrayOpener):]
+ if strings.HasPrefix(input, botArrayCloser) {
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ }
+ }
+ var collector []any
+ for len(input) > 0 && !arrayClosed {
+ //origInput = input
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) {
+ // The value wasn't quoted and has the array delimiter at the end, close the array
+ nextVal = strings.TrimRight(nextVal, botArrayCloser)
+ arrayClosed = true
+ } else if hasOpener && strings.HasPrefix(input, botArrayCloser) {
+ // The value was quoted or there was a space, and the next character is the
+ // array delimiter, close the array
+ input = strings.TrimLeft(input[len(botArrayCloser):], " ")
+ arrayClosed = true
+ } else if !hasOpener && !isLast {
+ // For array arguments in the middle without the <> delimiters, stop after the first item
+ arrayClosed = true
+ }
+ parsedVal, err := param.Schema.Items.ParseString(nextVal)
+ if err == nil {
+ collector = append(collector, parsedVal)
+ } else if hasOpener || isLast {
+ setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err))
+ } else {
+ //input = origInput
+ }
+ }
+ args[param.Key] = collector
+ } else {
+ nextVal, input, wasQuoted = parseQuoted(input)
+ if (isLast || isTail) && !wasQuoted && len(input) > 0 {
+ // If the last argument is not quoted, just treat the rest of the string
+ // as the argument without escapes (arguments with escapes should be quoted).
+ nextVal += " " + input
+ input = ""
+ }
+ // Special case for named boolean parameters: if no value is given, treat it as true
+ if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) {
+ args[param.Key] = true
+ return
+ }
+ if nextVal == "" && !wasQuoted && !isNamed && !param.Optional {
+ setError(fmt.Errorf("missing value for required parameter %s", param.Key))
+ }
+ parsedVal, err := param.Schema.ParseString(nextVal)
+ if err != nil {
+ args[param.Key] = param.GetDefaultValue()
+ // For optional parameters that fail to parse, restore the input and try passing it as the next parameter
+ if param.Optional && !isLast && !isNamed {
+ input = strings.TrimLeft(origInput, " ")
+ } else if !param.Optional || isNamed {
+ setError(fmt.Errorf("failed to parse %s: %w", param.Key, err))
+ }
+ } else {
+ args[param.Key] = parsedVal
+ }
+ }
+ }
+ skipParams := make([]bool, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ for strings.HasPrefix(input, "--") {
+ nameEndIdx := strings.IndexAny(input, " =")
+ if nameEndIdx == -1 {
+ nameEndIdx = len(input)
+ }
+ overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx])
+ if overrideParam != nil {
+ // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input
+ input = strings.TrimPrefix(input[nameEndIdx:], "=")
+ skipParams[paramIdx] = true
+ processParameter(overrideParam, false, false, true)
+ } else {
+ break
+ }
+ }
+ isTail := param.Key == ec.TailParam
+ if skipParams[i] || (param.Optional && !isTail) {
+ continue
+ }
+ processParameter(param, i == len(ec.Parameters)-1, isTail, false)
+ }
+ jsonArgs, marshalErr := json.Marshal(args)
+ if marshalErr != nil {
+ return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr)
+ }
+ return jsonArgs, retErr
+}
+
+func (ec *EventContent) parameterByName(name string) (*Parameter, int) {
+ for i, param := range ec.Parameters {
+ if strings.EqualFold(param.Key, name) {
+ return param, i
+ }
+ }
+ return nil, -1
+}
+
+func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) {
+ input := origInput
+ var chosenSigil string
+ for _, sigil := range sigils {
+ if strings.HasPrefix(input, sigil) {
+ chosenSigil = sigil
+ break
+ }
+ }
+ if chosenSigil == "" {
+ return ""
+ }
+ input = input[len(chosenSigil):]
+ var chosenAlias string
+ if !strings.HasPrefix(input, ec.Command) {
+ for _, alias := range ec.Aliases {
+ if strings.HasPrefix(input, alias) {
+ chosenAlias = alias
+ break
+ }
+ }
+ if chosenAlias == "" {
+ return ""
+ }
+ } else {
+ chosenAlias = ec.Command
+ }
+ input = strings.TrimPrefix(input[len(chosenAlias):], owner)
+ if input == "" || input[0] == ' ' {
+ input = strings.TrimLeft(input, " ")
+ return origInput[:len(origInput)-len(input)]
+ }
+ return ""
+}
+
+func (pt PrimitiveType) ValidateValue(value any) bool {
+ _, err := pt.NormalizeValue(value)
+ return err == nil
+}
+
+func normalizeNumber(value any) (int, error) {
+ switch typedValue := value.(type) {
+ case int:
+ return typedValue, nil
+ case int64:
+ return int(typedValue), nil
+ case float64:
+ return int(typedValue), nil
+ case json.Number:
+ if i, err := typedValue.Int64(); err != nil {
+ return 0, fmt.Errorf("failed to parse json.Number: %w", err)
+ } else {
+ return int(i), nil
+ }
+ default:
+ return 0, fmt.Errorf("unsupported type %T for integer", value)
+ }
+}
+
+func (pt PrimitiveType) NormalizeValue(value any) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return normalizeNumber(value)
+ case PrimitiveTypeBoolean:
+ bv, ok := value.(bool)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for boolean", value)
+ }
+ return bv, nil
+ case PrimitiveTypeString, PrimitiveTypeServerName:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for string", value)
+ }
+ return str, pt.validateStringValue(str)
+ case PrimitiveTypeUserID, PrimitiveTypeRoomAlias:
+ str, ok := value.(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value)
+ } else if plainErr := pt.validateStringValue(str); plainErr == nil {
+ return str, nil
+ } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID {
+ return parsed.UserID(), nil
+ } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias {
+ return parsed.RoomAlias(), nil
+ } else {
+ return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1)
+ }
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ riv, err := NormalizeRoomIDValue(value)
+ if err != nil {
+ return nil, err
+ }
+ return riv, riv.Validate()
+ default:
+ return nil, fmt.Errorf("cannot normalize value for argument type %s", pt)
+ }
+}
+
+func (pt PrimitiveType) validateStringValue(value string) error {
+ switch pt {
+ case PrimitiveTypeString:
+ return nil
+ case PrimitiveTypeServerName:
+ if !id.ValidateServerName(value) {
+ return fmt.Errorf("invalid server name: %q", value)
+ }
+ return nil
+ case PrimitiveTypeUserID:
+ _, _, err := id.UserID(value).ParseAndValidateRelaxed()
+ return err
+ case PrimitiveTypeRoomAlias:
+ sigil, localpart, serverName := id.ParseCommonIdentifier(value)
+ if sigil != '#' || localpart == "" || serverName == "" {
+ return fmt.Errorf("invalid room alias: %q", value)
+ } else if !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name in room alias: %q", serverName)
+ }
+ return nil
+ default:
+ panic(fmt.Errorf("validateStringValue called with invalid type %s", pt))
+ }
+}
+
+func parseBoolean(val string) (bool, error) {
+ if len(val) == 0 {
+ return false, fmt.Errorf("cannot parse empty string as boolean")
+ }
+ switch strings.ToLower(val) {
+ case "t", "true", "y", "yes", "1":
+ return true, nil
+ case "f", "false", "n", "no", "0":
+ return false, nil
+ default:
+ return false, fmt.Errorf("invalid boolean string: %q", val)
+ }
+}
+
+var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`)
+
+func parseRoomOrEventID(value string) (*RoomIDValue, error) {
+ if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") {
+ matches := markdownLinkRegex.FindStringSubmatch(value)
+ if len(matches) == 2 {
+ value = matches[1]
+ }
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil && strings.HasPrefix(value, "!") {
+ return &RoomIDValue{
+ Type: PrimitiveTypeRoomID,
+ RoomID: id.RoomID(value),
+ }, nil
+ }
+ if err != nil {
+ return nil, err
+ } else if parsed.Sigil1 != '!' {
+ return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1)
+ } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' {
+ return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2)
+ }
+ valType := PrimitiveTypeRoomID
+ if parsed.MXID2 != "" {
+ valType = PrimitiveTypeEventID
+ }
+ return &RoomIDValue{
+ Type: valType,
+ RoomID: parsed.RoomID(),
+ Via: parsed.Via,
+ EventID: parsed.EventID(),
+ }, nil
+}
+
+func (pt PrimitiveType) ParseString(value string) (any, error) {
+ switch pt {
+ case PrimitiveTypeInteger:
+ return strconv.Atoi(value)
+ case PrimitiveTypeBoolean:
+ return parseBoolean(value)
+ case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID:
+ return value, pt.validateStringValue(value)
+ case PrimitiveTypeRoomAlias:
+ plainErr := pt.validateStringValue(value)
+ if plainErr == nil {
+ return value, nil
+ }
+ parsed, err := id.ParseMatrixURIOrMatrixToURL(value)
+ if err != nil {
+ return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err)
+ } else if parsed.Sigil1 != '#' {
+ return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1)
+ }
+ return parsed.RoomAlias(), nil
+ case PrimitiveTypeRoomID, PrimitiveTypeEventID:
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, err
+ } else if pt != parsed.Type {
+ return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("cannot parse string for argument type %s", pt)
+ }
+}
+
+func (ps *ParameterSchema) ParseString(value string) (any, error) {
+ if ps == nil {
+ return nil, fmt.Errorf("parameter schema is nil")
+ }
+ switch ps.SchemaType {
+ case SchemaTypePrimitive:
+ return ps.Type.ParseString(value)
+ case SchemaTypeLiteral:
+ switch typedValue := ps.Value.(type) {
+ case string:
+ if value == typedValue {
+ return typedValue, nil
+ } else {
+ return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value)
+ }
+ case int, int64, float64, json.Number:
+ expectedVal, _ := normalizeNumber(typedValue)
+ intVal, err := strconv.Atoi(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse integer literal: %w", err)
+ } else if intVal != expectedVal {
+ return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal)
+ }
+ return intVal, nil
+ case bool:
+ boolVal, err := parseBoolean(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse boolean literal: %w", err)
+ } else if boolVal != typedValue {
+ return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal)
+ }
+ return boolVal, nil
+ case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage:
+ expectedVal, _ := NormalizeRoomIDValue(typedValue)
+ parsed, err := parseRoomOrEventID(value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err)
+ } else if !parsed.Equals(expectedVal) {
+ return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed)
+ }
+ return parsed, nil
+ default:
+ return nil, fmt.Errorf("unsupported literal type %T", ps.Value)
+ }
+ case SchemaTypeUnion:
+ var errs []error
+ for _, variant := range ps.Variants {
+ if parsed, err := variant.ParseString(value); err == nil {
+ return parsed, nil
+ } else {
+ errs = append(errs, err)
+ }
+ }
+ return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...))
+ case SchemaTypeArray:
+ return nil, fmt.Errorf("cannot parse string for array schema type")
+ default:
+ return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType)
+ }
+}
diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go
new file mode 100644
index 00000000..1e0d1817
--- /dev/null
+++ b/event/cmdschema/parse_test.go
@@ -0,0 +1,118 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "bytes"
+ "encoding/json"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "go.mau.fi/util/exbytes"
+ "go.mau.fi/util/exerrors"
+
+ "maunium.net/go/mautrix/event/cmdschema/testdata"
+)
+
+type QuoteParseOutput struct {
+ Parsed string
+ Remaining string
+ Quoted bool
+}
+
+func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error {
+ var arr []any
+ if err := json.Unmarshal(data, &arr); err != nil {
+ return err
+ }
+ qpo.Parsed = arr[0].(string)
+ qpo.Remaining = arr[1].(string)
+ qpo.Quoted = arr[2].(bool)
+ return nil
+}
+
+type QuoteParseTestData struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Output QuoteParseOutput `json:"output"`
+}
+
+func loadFile[T any](name string) (into T) {
+ quoteData := exerrors.Must(testdata.FS.ReadFile(name))
+ exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into))
+ return
+}
+
+func TestParseQuoted(t *testing.T) {
+ qptd := loadFile[[]QuoteParseTestData]("parse_quote.json")
+ for _, test := range qptd {
+ t.Run(test.Name, func(t *testing.T) {
+ parsed, remaining, quoted := parseQuoted(test.Input)
+ assert.Equalf(t, test.Output, QuoteParseOutput{
+ Parsed: parsed,
+ Remaining: remaining,
+ Quoted: quoted,
+ }, "Failed with input `%s`", test.Input)
+ // Note: can't just test that requoted == input, because some inputs
+ // have unnecessary escapes which won't survive roundtripping
+ t.Run("roundtrip", func(t *testing.T) {
+ requoted := quoteString(parsed) + " " + remaining
+ reparsed, newRemaining, _ := parseQuoted(requoted)
+ assert.Equal(t, parsed, reparsed)
+ assert.Equal(t, remaining, newRemaining)
+ })
+ })
+ }
+}
+
+type CommandTestData struct {
+ Spec *EventContent
+ Tests []*CommandTestUnit
+}
+
+type CommandTestUnit struct {
+ Name string `json:"name"`
+ Input string `json:"input"`
+ Broken string `json:"broken,omitempty"`
+ Error bool `json:"error"`
+ Output json.RawMessage `json:"output"`
+}
+
+func compactJSON(input json.RawMessage) json.RawMessage {
+ var buf bytes.Buffer
+ exerrors.PanicIfNotNil(json.Compact(&buf, input))
+ return buf.Bytes()
+}
+
+func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) {
+ for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) {
+ t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) {
+ ctd := loadFile[CommandTestData]("commands/" + cmd.Name())
+ for _, test := range ctd.Tests {
+ outputStr := exbytes.UnsafeString(compactJSON(test.Output))
+ t.Run(test.Name, func(t *testing.T) {
+ if test.Broken != "" {
+ t.Skip(test.Broken)
+ }
+ output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input)
+ if test.Error {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ if outputStr == "null" {
+ assert.Nil(t, output)
+ } else {
+ assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command)
+ assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go
new file mode 100644
index 00000000..98c421fc
--- /dev/null
+++ b/event/cmdschema/roomid.go
@@ -0,0 +1,135 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+ "strings"
+
+ "maunium.net/go/mautrix/id"
+)
+
+var ParameterSchemaJoinableRoom = Union(
+ PrimitiveTypeRoomID.Schema(),
+ PrimitiveTypeRoomAlias.Schema(),
+)
+
+type RoomIDValue struct {
+ Type PrimitiveType `json:"type"`
+ RoomID id.RoomID `json:"id"`
+ Via []string `json:"via,omitempty"`
+ EventID id.EventID `json:"event_id,omitempty"`
+}
+
+func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) {
+ switch typedValue := input.(type) {
+ case map[string]any, json.RawMessage:
+ var raw json.RawMessage
+ if raw, err = json.Marshal(input); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ } else if err = json.Unmarshal(raw, &riv); err != nil {
+ err = fmt.Errorf("failed to roundtrip room ID value: %w", err)
+ }
+ case *RoomIDValue:
+ riv = typedValue
+ case RoomIDValue:
+ riv = &typedValue
+ default:
+ err = fmt.Errorf("unsupported type %T for room or event ID", input)
+ }
+ return
+}
+
+func (riv *RoomIDValue) String() string {
+ return riv.URI().String()
+}
+
+func (riv *RoomIDValue) URI() *id.MatrixURI {
+ if riv == nil {
+ return nil
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ return riv.RoomID.URI(riv.Via...)
+ case PrimitiveTypeEventID:
+ return riv.RoomID.EventURI(riv.EventID, riv.Via...)
+ default:
+ return nil
+ }
+}
+
+func (riv *RoomIDValue) Equals(other *RoomIDValue) bool {
+ if riv == nil || other == nil {
+ return riv == other
+ }
+ return riv.Type == other.Type &&
+ riv.RoomID == other.RoomID &&
+ riv.EventID == other.EventID &&
+ slices.Equal(riv.Via, other.Via)
+}
+
+func (riv *RoomIDValue) Validate() error {
+ if riv == nil {
+ return fmt.Errorf("value is nil")
+ }
+ switch riv.Type {
+ case PrimitiveTypeRoomID:
+ if riv.EventID != "" {
+ return fmt.Errorf("event ID must be empty for room ID type")
+ }
+ case PrimitiveTypeEventID:
+ if !strings.HasPrefix(riv.EventID.String(), "$") {
+ return fmt.Errorf("event ID not valid: %q", riv.EventID)
+ }
+ default:
+ return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type)
+ }
+ for _, via := range riv.Via {
+ if !id.ValidateServerName(via) {
+ return fmt.Errorf("invalid server name %q in vias", via)
+ }
+ }
+ sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID)
+ if sigil != '!' {
+ return fmt.Errorf("room ID does not start with !: %q", riv.RoomID)
+ } else if localpart == "" && serverName == "" {
+ return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID)
+ } else if serverName != "" && !id.ValidateServerName(serverName) {
+ return fmt.Errorf("invalid server name %q in room ID", serverName)
+ }
+ return nil
+}
+
+func (riv *RoomIDValue) IsValid() bool {
+ return riv.Validate() == nil
+}
+
+type RoomIDOrString string
+
+func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error {
+ if len(data) == 0 {
+ return fmt.Errorf("empty data for room ID or string")
+ }
+ if data[0] == '"' {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(str)
+ return nil
+ }
+ var riv RoomIDValue
+ if err := json.Unmarshal(data, &riv); err != nil {
+ return err
+ } else if err = riv.Validate(); err != nil {
+ return err
+ }
+ *ros = RoomIDOrString(riv.String())
+ return nil
+}
diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go
new file mode 100644
index 00000000..c5c57c53
--- /dev/null
+++ b/event/cmdschema/stringify.go
@@ -0,0 +1,122 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package cmdschema
+
+import (
+ "encoding/json"
+ "strconv"
+ "strings"
+)
+
+var quoteEscaper = strings.NewReplacer(
+ `"`, `\"`,
+ `\`, `\\`,
+)
+
+const charsToQuote = ` \` + botArrayOpener + botArrayCloser
+
+func quoteString(val string) string {
+ if val == "" {
+ return `""`
+ }
+ val = quoteEscaper.Replace(val)
+ if strings.ContainsAny(val, charsToQuote) {
+ return `"` + val + `"`
+ }
+ return val
+}
+
+func (ec *EventContent) StringifyArgs(args any) string {
+ var argMap map[string]any
+ switch typedArgs := args.(type) {
+ case json.RawMessage:
+ err := json.Unmarshal(typedArgs, &argMap)
+ if err != nil {
+ return ""
+ }
+ case map[string]any:
+ argMap = typedArgs
+ default:
+ if b, err := json.Marshal(args); err != nil {
+ return ""
+ } else if err = json.Unmarshal(b, &argMap); err != nil {
+ return ""
+ }
+ }
+ parts := make([]string, 0, len(ec.Parameters))
+ for i, param := range ec.Parameters {
+ isLast := i == len(ec.Parameters)-1
+ val := argMap[param.Key]
+ if val == nil {
+ val = param.DefaultValue
+ if val == nil && !param.Optional {
+ val = param.Schema.GetDefaultValue()
+ }
+ }
+ if val == nil {
+ continue
+ }
+ var stringified string
+ if param.Schema.SchemaType == SchemaTypeArray {
+ stringified = arrayArgumentToString(val, isLast)
+ } else {
+ stringified = singleArgumentToString(val)
+ }
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ return strings.Join(parts, " ")
+}
+
+func arrayArgumentToString(val any, isLast bool) string {
+ valArr, ok := val.([]any)
+ if !ok {
+ return ""
+ }
+ parts := make([]string, 0, len(valArr))
+ for _, elem := range valArr {
+ stringified := singleArgumentToString(elem)
+ if stringified != "" {
+ parts = append(parts, stringified)
+ }
+ }
+ joinedParts := strings.Join(parts, " ")
+ if isLast && len(parts) > 0 {
+ return joinedParts
+ }
+ return botArrayOpener + joinedParts + botArrayCloser
+}
+
+func singleArgumentToString(val any) string {
+ switch typedVal := val.(type) {
+ case string:
+ return quoteString(typedVal)
+ case json.Number:
+ return typedVal.String()
+ case bool:
+ return strconv.FormatBool(typedVal)
+ case int:
+ return strconv.Itoa(typedVal)
+ case int64:
+ return strconv.FormatInt(typedVal, 10)
+ case float64:
+ return strconv.FormatInt(int64(typedVal), 10)
+ case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue:
+ normalized, err := NormalizeRoomIDValue(typedVal)
+ if err != nil {
+ return ""
+ }
+ uri := normalized.URI()
+ if uri == nil {
+ return ""
+ }
+ return quoteString(uri.String())
+ default:
+ return ""
+ }
+}
diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json
new file mode 100644
index 00000000..e53382db
--- /dev/null
+++ b/event/cmdschema/testdata/commands.schema.json
@@ -0,0 +1,281 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "commands.schema.json",
+ "title": "ParseInput test cases",
+ "description": "JSON schema for test case files containing command specifications and test cases",
+ "type": "object",
+ "required": [
+ "spec",
+ "tests"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "spec": {
+ "title": "MSC4391 Command Description",
+ "description": "JSON schema defining the structure of a bot command event content",
+ "type": "object",
+ "required": [
+ "command"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "command": {
+ "type": "string",
+ "description": "The command name that triggers this bot command"
+ },
+ "aliases": {
+ "type": "array",
+ "description": "Alternative names/aliases for this command",
+ "items": {
+ "type": "string"
+ }
+ },
+ "parameters": {
+ "type": "array",
+ "description": "List of parameters accepted by this command",
+ "items": {
+ "$ref": "#/$defs/Parameter"
+ }
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of the command"
+ },
+ "fi.mau.tail_parameter": {
+ "type": "string",
+ "description": "The key of the parameter that accepts remaining arguments as tail text"
+ },
+ "source": {
+ "type": "string",
+ "description": "The user ID of the bot that responds to this command"
+ }
+ }
+ },
+ "tests": {
+ "type": "array",
+ "description": "Array of test cases for the command",
+ "items": {
+ "type": "object",
+ "description": "A single test case for command parsing",
+ "required": [
+ "name",
+ "input"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "The name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "The command input string to parse"
+ },
+ "output": {
+ "description": "The expected parsed parameter values, or null if the parsing is expected to fail",
+ "oneOf": [
+ {
+ "type": "object",
+ "additionalProperties": true
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
+ "error": {
+ "type": "boolean",
+ "description": "Whether parsing should result in an error. May still produce output.",
+ "default": false
+ }
+ }
+ }
+ }
+ },
+ "$defs": {
+ "ExtensibleTextContainer": {
+ "type": "object",
+ "description": "Container for text that can have multiple representations",
+ "required": [
+ "m.text"
+ ],
+ "properties": {
+ "m.text": {
+ "type": "array",
+ "description": "Array of text representations in different formats",
+ "items": {
+ "$ref": "#/$defs/ExtensibleText"
+ }
+ }
+ }
+ },
+ "ExtensibleText": {
+ "type": "object",
+ "description": "A text representation with a specific MIME type",
+ "required": [
+ "body"
+ ],
+ "properties": {
+ "body": {
+ "type": "string",
+ "description": "The text content"
+ },
+ "mimetype": {
+ "type": "string",
+ "description": "The MIME type of the text (e.g., text/plain, text/html)",
+ "default": "text/plain",
+ "examples": [
+ "text/plain",
+ "text/html"
+ ]
+ }
+ }
+ },
+ "Parameter": {
+ "type": "object",
+ "description": "A parameter definition for a command",
+ "required": [
+ "key",
+ "schema"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "key": {
+ "type": "string",
+ "description": "The identifier for this parameter"
+ },
+ "schema": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema defining the type and structure of this parameter"
+ },
+ "optional": {
+ "type": "boolean",
+ "description": "Whether this parameter is optional",
+ "default": false
+ },
+ "description": {
+ "$ref": "#/$defs/ExtensibleTextContainer",
+ "description": "Human-readable description of this parameter"
+ },
+ "fi.mau.default_value": {
+ "description": "Default value for this parameter if not provided"
+ }
+ }
+ },
+ "ParameterSchema": {
+ "type": "object",
+ "description": "Schema definition for a parameter value",
+ "required": [
+ "schema_type"
+ ],
+ "additionalProperties": false,
+ "properties": {
+ "schema_type": {
+ "type": "string",
+ "enum": [
+ "primitive",
+ "array",
+ "union",
+ "literal"
+ ],
+ "description": "The type of schema"
+ }
+ },
+ "allOf": [
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "primitive"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "type"
+ ],
+ "properties": {
+ "type": {
+ "type": "string",
+ "enum": [
+ "string",
+ "integer",
+ "boolean",
+ "server_name",
+ "user_id",
+ "room_id",
+ "room_alias",
+ "event_id"
+ ],
+ "description": "The primitive type (only for schema_type: primitive)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "array"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "items"
+ ],
+ "properties": {
+ "items": {
+ "$ref": "#/$defs/ParameterSchema",
+ "description": "The schema for array items (only for schema_type: array)"
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "union"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "variants"
+ ],
+ "properties": {
+ "variants": {
+ "type": "array",
+ "description": "The possible variants (only for schema_type: union)",
+ "items": {
+ "$ref": "#/$defs/ParameterSchema"
+ },
+ "minItems": 1
+ }
+ }
+ }
+ },
+ {
+ "if": {
+ "properties": {
+ "schema_type": {
+ "const": "literal"
+ }
+ }
+ },
+ "then": {
+ "required": [
+ "value"
+ ],
+ "properties": {
+ "value": {
+ "description": "The literal value (only for schema_type: literal)"
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+}
diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json
new file mode 100644
index 00000000..6ce1f4da
--- /dev/null
+++ b/event/cmdschema/testdata/commands/flags.json
@@ -0,0 +1,126 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "flag",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "user",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "user_id"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true,
+ "fi.mau.default_value": false
+ }
+ ],
+ "fi.mau.tail_parameter": "user"
+ },
+ "tests": [
+ {
+ "name": "no flags",
+ "input": "/flag mrrp",
+ "output": {
+ "meow": "mrrp",
+ "user": null
+ }
+ },
+ {
+ "name": "no flags, has tail",
+ "input": "/flag mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com"
+ }
+ },
+ {
+ "name": "named flag at start",
+ "input": "/flag --woof=yes mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "boolean flag without value",
+ "input": "/flag --woof mrrp @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "user id flag without value",
+ "input": "/flag --user --woof mrrp",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle",
+ "input": "/flag mrrp --woof=yes @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "named flag in the middle with different value",
+ "input": "/flag mrrp --woof=no @user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named",
+ "input": "/flag --woof=no --meow=mrrp --user=@user:example.com",
+ "output": {
+ "meow": "mrrp",
+ "user": "@user:example.com",
+ "woof": false
+ }
+ },
+ {
+ "name": "all variables named with quotes",
+ "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"",
+ "output": {
+ "meow": "meow meow mrrp",
+ "user": "@user:example.com",
+ "woof": true
+ }
+ },
+ {
+ "name": "invalid value for named parameter",
+ "input": "/flag --user=meowings mrrp --woof",
+ "error": true,
+ "output": {
+ "meow": "mrrp",
+ "user": null,
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json
new file mode 100644
index 00000000..1351c292
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_id_or_alias.json
@@ -0,0 +1,85 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "room",
+ "schema": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "room": "#test:matrix.org"
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix.to link with url encoding",
+ "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net",
+ "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!#test/room\nversion 11, with @🐈️:maunium.net",
+ "via": [
+ "maunium.net"
+ ]
+ }
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "room": {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json
new file mode 100644
index 00000000..aa266054
--- /dev/null
+++ b/event/cmdschema/testdata/commands/room_reference_list.json
@@ -0,0 +1,106 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test room reference",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "rooms",
+ "schema": {
+ "schema_type": "array",
+ "items": {
+ "schema_type": "union",
+ "variants": [
+ {
+ "schema_type": "primitive",
+ "type": "room_id"
+ },
+ {
+ "schema_type": "primitive",
+ "type": "room_alias"
+ }
+ ]
+ }
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "room alias",
+ "input": "/test room reference #test:matrix.org",
+ "output": {
+ "rooms": [
+ "#test:matrix.org"
+ ]
+ }
+ },
+ {
+ "name": "room id",
+ "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "two room ids",
+ "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"
+ },
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org"
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI",
+ "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ },
+ {
+ "name": "room id matrix: URI and matrix.to URL",
+ "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org",
+ "output": {
+ "rooms": [
+ {
+ "type": "room_id",
+ "id": "!aiwVrNhPwbGBNjqlNu:matrix.org",
+ "via": [
+ "example.com"
+ ]
+ },
+ {
+ "type": "room_id",
+ "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ",
+ "via": [
+ "maunium.net",
+ "matrix.org"
+ ]
+ }
+ ]
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json
new file mode 100644
index 00000000..94667323
--- /dev/null
+++ b/event/cmdschema/testdata/commands/simple.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "test simple",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ }
+ ]
+ },
+ "tests": [
+ {
+ "name": "success",
+ "input": "/test simple mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "directed success",
+ "input": "/test simple@testbot mrrp",
+ "output": {
+ "meow": "mrrp"
+ }
+ },
+ {
+ "name": "missing parameter",
+ "input": "/test simple",
+ "error": true,
+ "output": {
+ "meow": ""
+ }
+ },
+ {
+ "name": "directed at another bot",
+ "input": "/test simple@anotherbot mrrp",
+ "error": false,
+ "output": null
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json
new file mode 100644
index 00000000..9782f8ec
--- /dev/null
+++ b/event/cmdschema/testdata/commands/tail.json
@@ -0,0 +1,60 @@
+{
+ "$schema": "../commands.schema.json#",
+ "spec": {
+ "command": "tail",
+ "source": "@testbot",
+ "parameters": [
+ {
+ "key": "meow",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ }
+ },
+ {
+ "key": "reason",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "string"
+ },
+ "optional": true
+ },
+ {
+ "key": "woof",
+ "schema": {
+ "schema_type": "primitive",
+ "type": "boolean"
+ },
+ "optional": true
+ }
+ ],
+ "fi.mau.tail_parameter": "reason"
+ },
+ "tests": [
+ {
+ "name": "no tail or flag",
+ "input": "/tail mrrp",
+ "output": {
+ "meow": "mrrp",
+ "reason": ""
+ }
+ },
+ {
+ "name": "tail, no flag",
+ "input": "/tail mrrp meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow"
+ }
+ },
+ {
+ "name": "flag before tail",
+ "input": "/tail mrrp --woof meow meow",
+ "output": {
+ "meow": "mrrp",
+ "reason": "meow meow",
+ "woof": true
+ }
+ }
+ ]
+}
diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go
new file mode 100644
index 00000000..eceea3d2
--- /dev/null
+++ b/event/cmdschema/testdata/data.go
@@ -0,0 +1,14 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package testdata
+
+import (
+ "embed"
+)
+
+//go:embed *
+var FS embed.FS
diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json
new file mode 100644
index 00000000..8f52b7f5
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.json
@@ -0,0 +1,30 @@
+[
+ {"name": "empty string", "input": "", "output": ["", "", false]},
+ {"name": "single word", "input": "meow", "output": ["meow", "", false]},
+ {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]},
+ {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]},
+ {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]},
+ {"name": "only spaces", "input": " ", "output": ["", "", false]},
+ {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]},
+ {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]},
+ {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]},
+ {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]},
+ {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]},
+ {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]},
+ {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]},
+ {"name": "just opening quote", "input": "\"", "output": ["", "", true]},
+ {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]},
+ {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]},
+ {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]},
+ {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]},
+ {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]},
+ {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]},
+ {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]},
+ {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]},
+ {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]},
+ {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]},
+ {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]},
+ {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]}
+]
diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json
new file mode 100644
index 00000000..9f249116
--- /dev/null
+++ b/event/cmdschema/testdata/parse_quote.schema.json
@@ -0,0 +1,46 @@
+{
+ "$schema": "https://json-schema.org/draft/2020-12/schema#",
+ "$id": "parse_quote.schema.json",
+ "title": "parseQuote test cases",
+ "description": "Test cases for the parseQuoted function",
+ "type": "array",
+ "items": {
+ "type": "object",
+ "required": [
+ "name",
+ "input",
+ "output"
+ ],
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "Name of the test case"
+ },
+ "input": {
+ "type": "string",
+ "description": "Input string to be parsed"
+ },
+ "output": {
+ "type": "array",
+ "description": "Expected output of parsing: [first word, remaining text, was quoted]",
+ "minItems": 3,
+ "maxItems": 3,
+ "prefixItems": [
+ {
+ "type": "string",
+ "description": "First parsed word"
+ },
+ {
+ "type": "string",
+ "description": "Remaining text after the first word"
+ },
+ {
+ "type": "boolean",
+ "description": "Whether the first word was quoted"
+ }
+ ]
+ }
+ },
+ "additionalProperties": false
+ }
+}
diff --git a/event/content.go b/event/content.go
index c0ff51ad..814aeec4 100644
--- a/event/content.go
+++ b/event/content.go
@@ -40,6 +40,9 @@ var TypeMap = map[Type]reflect.Type{
StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}),
StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}),
+ StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+ StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}),
+
StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}),
StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}),
StateLegacyPolicyUser: reflect.TypeOf(ModPolicyContent{}),
@@ -50,7 +53,6 @@ var TypeMap = map[Type]reflect.Type{
StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}),
StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}),
StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}),
- StateBotCommands: reflect.TypeOf(BotCommandsEventContent{}),
EventMessage: reflect.TypeOf(MessageEventContent{}),
EventSticker: reflect.TypeOf(MessageEventContent{}),
@@ -61,9 +63,11 @@ var TypeMap = map[Type]reflect.Type{
EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}),
EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}),
- BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
- BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
- BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}),
+ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}),
+ BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}),
+ BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}),
+ BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}),
+ BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}),
AccountDataRoomTags: reflect.TypeOf(TagEventContent{}),
AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}),
@@ -72,9 +76,11 @@ var TypeMap = map[Type]reflect.Type{
AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}),
AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}),
- EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
- EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
- EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}),
+ EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}),
+ EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),
+ EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}),
+ BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}),
InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}),
InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}),
diff --git a/event/encryption.go b/event/encryption.go
index cf9c2814..c60cb91a 100644
--- a/event/encryption.go
+++ b/event/encryption.go
@@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error {
return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext)
case id.AlgorithmMegolmV1:
if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' {
- return id.InputNotJSONString
+ return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString)
}
content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1]
}
@@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct {
type RequestedKeyInfo struct {
Algorithm id.Algorithm `json:"algorithm"`
RoomID id.RoomID `json:"room_id"`
- SenderKey id.SenderKey `json:"sender_key"`
SessionID id.SessionID `json:"session_id"`
+ // Deprecated: Matrix v1.3
+ SenderKey id.SenderKey `json:"sender_key"`
}
type RoomKeyWithheldCode string
diff --git a/event/message.go b/event/message.go
index 692382cf..3fb3dc82 100644
--- a/event/message.go
+++ b/event/message.go
@@ -135,6 +135,7 @@ type MessageEventContent struct {
BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"`
+ BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"`
BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"`
@@ -143,7 +144,7 @@ type MessageEventContent struct {
MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"`
MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"`
- MSC4332BotCommand *BotCommandInput `json:"org.matrix.msc4332.command,omitempty"`
+ MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"`
}
func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType {
@@ -287,6 +288,13 @@ func (m *Mentions) Merge(other *Mentions) *Mentions {
}
}
+type MSC4391BotCommandInputCustom[T any] struct {
+ Command string `json:"command"`
+ Arguments T `json:"arguments,omitempty"`
+}
+
+type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage]
+
type EncryptedFileInfo struct {
attachment.EncryptedFile
URL id.ContentURIString `json:"url"`
diff --git a/event/powerlevels.go b/event/powerlevels.go
index 708721f9..668eb6d3 100644
--- a/event/powerlevels.go
+++ b/event/powerlevels.go
@@ -28,6 +28,9 @@ type PowerLevelsEventContent struct {
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
+ beeperEphemeralLock sync.RWMutex
+ BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"`
+
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
StateDefaultPtr *int `json:"state_default,omitempty"`
@@ -37,6 +40,8 @@ type PowerLevelsEventContent struct {
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
+ BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"`
+
// This is not a part of power levels, it's added by mautrix-go internally in certain places
// in order to detect creator power accurately.
CreateEvent *Event `json:"-"`
@@ -51,6 +56,7 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
UsersDefault: pl.UsersDefault,
Events: maps.Clone(pl.Events),
EventsDefault: pl.EventsDefault,
+ BeeperEphemeral: maps.Clone(pl.BeeperEphemeral),
StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr),
Notifications: pl.Notifications.Clone(),
@@ -60,6 +66,8 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
BanPtr: ptr.Clone(pl.BanPtr),
RedactPtr: ptr.Clone(pl.RedactPtr),
+ BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr),
+
CreateEvent: pl.CreateEvent,
}
}
@@ -119,6 +127,13 @@ func (pl *PowerLevelsEventContent) StateDefault() int {
return 50
}
+func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int {
+ if pl.BeeperEphemeralDefaultPtr != nil {
+ return *pl.BeeperEphemeralDefaultPtr
+ }
+ return pl.EventsDefault
+}
+
func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int {
if pl.isCreator(userID) {
return math.MaxInt
@@ -202,6 +217,29 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int {
return level
}
+func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int {
+ pl.beeperEphemeralLock.RLock()
+ defer pl.beeperEphemeralLock.RUnlock()
+ level, ok := pl.BeeperEphemeral[eventType.String()]
+ if !ok {
+ return pl.BeeperEphemeralDefault()
+ }
+ return level
+}
+
+func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) {
+ pl.beeperEphemeralLock.Lock()
+ defer pl.beeperEphemeralLock.Unlock()
+ if level == pl.BeeperEphemeralDefault() {
+ delete(pl.BeeperEphemeral, eventType.String())
+ } else {
+ if pl.BeeperEphemeral == nil {
+ pl.BeeperEphemeral = make(map[string]int)
+ }
+ pl.BeeperEphemeral[eventType.String()] = level
+ }
+}
+
func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) {
pl.eventsLock.Lock()
defer pl.eventsLock.Unlock()
diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go
new file mode 100644
index 00000000..f5861583
--- /dev/null
+++ b/event/powerlevels_ephemeral_test.go
@@ -0,0 +1,67 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package event_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "maunium.net/go/mautrix/event"
+)
+
+func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 45,
+ }
+
+ assert.Equal(t, 45, pl.BeeperEphemeralDefault())
+
+ override := 60
+ pl.BeeperEphemeralDefaultPtr = &override
+ assert.Equal(t, 60, pl.BeeperEphemeralDefault())
+}
+
+func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) {
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 25,
+ }
+ evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
+
+ assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType))
+
+ pl.SetBeeperEphemeralLevel(evtType, 50)
+ assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType))
+ require.NotNil(t, pl.BeeperEphemeral)
+ assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()])
+
+ pl.SetBeeperEphemeralLevel(evtType, 25)
+ _, exists := pl.BeeperEphemeral[evtType.String()]
+ assert.False(t, exists)
+}
+
+func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) {
+ override := 70
+ pl := &event.PowerLevelsEventContent{
+ EventsDefault: 35,
+ BeeperEphemeral: map[string]int{"com.example.ephemeral": 90},
+ BeeperEphemeralDefaultPtr: &override,
+ }
+
+ cloned := pl.Clone()
+ require.NotNil(t, cloned)
+ require.NotNil(t, cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr)
+ assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"])
+
+ cloned.BeeperEphemeral["com.example.ephemeral"] = 99
+ *cloned.BeeperEphemeralDefaultPtr = 71
+
+ assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"])
+ assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr)
+}
diff --git a/event/state.go b/event/state.go
index 6df3b143..ace170a5 100644
--- a/event/state.go
+++ b/event/state.go
@@ -62,6 +62,13 @@ type ExtensibleTextContainer struct {
Text []ExtensibleText `json:"m.text"`
}
+func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool {
+ if c == nil || description == nil {
+ return c == description
+ }
+ return slices.Equal(c.Text, description.Text)
+}
+
func MakeExtensibleText(text string) *ExtensibleTextContainer {
return &ExtensibleTextContainer{
Text: []ExtensibleText{{
@@ -231,7 +238,8 @@ type BridgeInfoSection struct {
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
ExternalURL string `json:"external_url,omitempty"`
- Receiver string `json:"fi.mau.receiver,omitempty"`
+ Receiver string `json:"fi.mau.receiver,omitempty"`
+ MessageRequest bool `json:"com.beeper.message_request,omitempty"`
}
// BridgeEventContent represents the content of a m.bridge state event.
@@ -335,3 +343,15 @@ func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool {
efmc.ServiceMembers = append(efmc.ServiceMembers, mxid)
return true
}
+
+type PolicyServerPublicKeys struct {
+ Ed25519 id.Ed25519 `json:"ed25519,omitempty"`
+}
+
+type RoomPolicyEventContent struct {
+ Via string `json:"via,omitempty"`
+ PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"`
+
+ // Deprecated, only for legacy use
+ PublicKey id.Ed25519 `json:"public_key,omitempty"`
+}
diff --git a/event/type.go b/event/type.go
index 56ea82f6..80b86728 100644
--- a/event/type.go
+++ b/event/type.go
@@ -113,9 +113,9 @@ func (et *Type) GuessClass() TypeClass {
StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type,
StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type,
StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type,
- StateBotCommands.Type:
+ StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type:
return StateEventType
- case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type:
+ case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type:
return EphemeralEventType
case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type,
AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type,
@@ -128,7 +128,7 @@ func (et *Type) GuessClass() TypeClass {
InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type,
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type,
- EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type:
+ EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
@@ -195,6 +195,9 @@ var (
StateSpaceChild = Type{"m.space.child", StateEventType}
StateSpaceParent = Type{"m.space.parent", StateEventType}
+ StateRoomPolicy = Type{"m.room.policy", StateEventType}
+ StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType}
+
StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType}
StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType}
StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType}
@@ -205,7 +208,7 @@ var (
StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType}
StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType}
StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType}
- StateBotCommands = Type{"org.matrix.msc4332.commands", StateEventType}
+ StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType}
)
// Message events
@@ -234,9 +237,11 @@ var (
CallNegotiate = Type{"m.call.negotiate", MessageEventType}
CallHangup = Type{"m.call.hangup", MessageEventType}
- BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
- BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
- BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType}
+ BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType}
+ BeeperTranscription = Type{"com.beeper.transcription", MessageEventType}
+ BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType}
+ BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType}
+ BeeperSendState = Type{"com.beeper.send_state", MessageEventType}
EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType}
EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType}
@@ -245,9 +250,11 @@ var (
// Ephemeral events
var (
- EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
- EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
- EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType}
+ EphemeralEventTyping = Type{"m.typing", EphemeralEventType}
+ EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
+ EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType}
+ BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType}
)
// Account data events
diff --git a/federation/client.go b/federation/client.go
index b24fd2d2..183fb5d1 100644
--- a/federation/client.go
+++ b/federation/client.go
@@ -263,6 +263,169 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken
return
}
+type ReqMakeJoin struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+ SupportedVersions []id.RoomVersion
+}
+
+type RespMakeJoin struct {
+ RoomVersion id.RoomVersion `json:"room_version"`
+ Event PDU `json:"event"`
+}
+
+type ReqSendJoin struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ OmitMembers bool
+ Event PDU
+ Via string
+}
+
+type ReqSendKnock struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type RespSendJoin struct {
+ AuthChain []PDU `json:"auth_chain"`
+ Event PDU `json:"event"`
+ MembersOmitted bool `json:"members_omitted"`
+ ServersInRoom []string `json:"servers_in_room"`
+ State []PDU `json:"state"`
+}
+
+type RespSendKnock struct {
+ KnockRoomState []PDU `json:"knock_room_state"`
+}
+
+type ReqSendInvite struct {
+ RoomID id.RoomID `json:"-"`
+ UserID id.UserID `json:"-"`
+ Event PDU `json:"event"`
+ InviteRoomState []PDU `json:"invite_room_state"`
+ RoomVersion id.RoomVersion `json:"room_version"`
+}
+
+type RespSendInvite struct {
+ Event PDU `json:"event"`
+}
+
+type ReqMakeLeave struct {
+ RoomID id.RoomID
+ UserID id.UserID
+ Via string
+}
+
+type ReqSendLeave struct {
+ RoomID id.RoomID
+ EventID id.EventID
+ Event PDU
+ Via string
+}
+
+type (
+ ReqMakeKnock = ReqMakeJoin
+ RespMakeKnock = RespMakeJoin
+ RespMakeLeave = RespMakeJoin
+)
+
+func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_join", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) {
+ versions := make([]string, len(req.SupportedVersions))
+ for i, v := range req.SupportedVersions {
+ versions[i] = string(v)
+ }
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID},
+ Query: url.Values{"ver": versions},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_join", req.RoomID, req.EventID},
+ Query: url.Values{
+ "omit_members": {strconv.FormatBool(req.OmitMembers)},
+ },
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.UserID.Homeserver(),
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "invite", req.RoomID, req.UserID},
+ Authenticate: true,
+ RequestJSON: req,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodGet,
+ Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID},
+ Authenticate: true,
+ ResponseJSON: &resp,
+ })
+ return
+}
+
+func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) {
+ _, _, err = c.MakeFullRequest(ctx, RequestParams{
+ ServerName: req.Via,
+ Method: http.MethodPut,
+ Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID},
+ Authenticate: true,
+ RequestJSON: req.Event,
+ })
+ return
+}
+
type URLPath []any
func (fup URLPath) FullPath() []any {
diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go
index 32b4424b..c72933c2 100644
--- a/federation/eventauth/eventauth.go
+++ b/federation/eventauth/eventauth.go
@@ -484,7 +484,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv
}
return ErrCantLeaveWithoutBeingInRoom
}
- if senderMembership != event.MembershipLeave {
+ if senderMembership != event.MembershipJoin {
// 5.5.2. If the sender’s current membership state is not join, reject.
return ErrCantKickWithoutBeingInRoom
}
@@ -505,7 +505,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv
// 5.5.5. Otherwise, reject.
return ErrInsufficientPermissionForKick
case event.MembershipBan:
- if senderMembership != event.MembershipLeave {
+ if senderMembership != event.MembershipJoin {
// 5.6.1. If the sender’s current membership state is not join, reject.
return ErrCantBanWithoutBeingInRoom
}
diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go
new file mode 100644
index 00000000..d316f3c8
--- /dev/null
+++ b/federation/eventauth/eventauth_internal_test.go
@@ -0,0 +1,66 @@
+// Copyright (c) 2026 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//go:build goexperiment.jsonv2
+
+package eventauth
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+type pythonIntTest struct {
+ Name string
+ Input string
+ Expected int64
+}
+
+var pythonIntTests = []pythonIntTest{
+ {"True", `true`, 1},
+ {"False", `false`, 0},
+ {"SmallFloat", `3.1415`, 3},
+ {"SmallFloatRoundDown", `10.999999999999999`, 10},
+ {"SmallFloatRoundUp", `10.9999999999999999`, 11},
+ {"BigFloatRoundDown", `1000000.9999999999`, 1000000},
+ {"BigFloatRoundUp", `1000000.99999999999`, 1000001},
+ {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992},
+ {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994},
+ {"Int64", `9223372036854775807`, 9223372036854775807},
+ {"Int64String", `"9223372036854775807"`, 9223372036854775807},
+ {"String", `"123"`, 123},
+ {"InvalidFloatInString", `"123.456"`, 0},
+ {"StringWithPlusSign", `"+123"`, 123},
+ {"StringWithMinusSign", `"-123"`, -123},
+ {"StringWithSpaces", `" 123 "`, 123},
+ {"StringWithSpacesAndSign", `" -123 "`, -123},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ //{"StringWithUnderscores", `"123_456"`, 123456},
+ {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0},
+ {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0},
+ {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0},
+ {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0},
+ //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456},
+}
+
+func TestParsePythonInt(t *testing.T) {
+ for _, test := range pythonIntTests {
+ t.Run(test.Name, func(t *testing.T) {
+ output := parsePythonInt(gjson.Parse(test.Input))
+ if strings.HasPrefix(test.Name, "Invalid") {
+ assert.Nil(t, output)
+ } else {
+ require.NotNil(t, output)
+ assert.Equal(t, int(test.Expected), *output)
+ }
+ })
+ }
+}
diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go
index cecee5b9..17db6995 100644
--- a/federation/pdu/pdu.go
+++ b/federation/pdu/pdu.go
@@ -123,6 +123,19 @@ func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error)
return evt, nil
}
+func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) {
+ if signature == "" {
+ return
+ }
+ if pdu.Signatures == nil {
+ pdu.Signatures = make(map[string]map[id.KeyID]string)
+ }
+ if _, ok := pdu.Signatures[serverName]; !ok {
+ pdu.Signatures[serverName] = make(map[id.KeyID]string)
+ }
+ pdu.Signatures[serverName][keyID] = signature
+}
+
func marshalCanonical(data any) (jsontext.Value, error) {
marshaledBytes, err := json.Marshal(data)
if err != nil {
diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go
index a7685cc6..04e7c5ef 100644
--- a/federation/pdu/signature.go
+++ b/federation/pdu/signature.go
@@ -28,13 +28,7 @@ func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.Key
return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err)
}
signature := ed25519.Sign(privateKey, rawJSON)
- if pdu.Signatures == nil {
- pdu.Signatures = make(map[string]map[id.KeyID]string)
- }
- if _, ok := pdu.Signatures[serverName]; !ok {
- pdu.Signatures[serverName] = make(map[id.KeyID]string)
- }
- pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature)
+ pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature))
return nil
}
diff --git a/federation/resolution.go b/federation/resolution.go
index 81e19cfb..a3188266 100644
--- a/federation/resolution.go
+++ b/federation/resolution.go
@@ -80,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS
} else if wellKnown != nil {
output.Expires = expiry
output.HostHeader = wellKnown.Server
- hostname, port, ok = ParseServerName(wellKnown.Server)
+ wkHost, wkPort, ok := ParseServerName(wellKnown.Server)
+ if ok {
+ hostname, port = wkHost, wkPort
+ }
// Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known
if net.ParseIP(hostname) != nil || port != 0 {
if port == 0 {
diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go
index 633a0f66..f99fc6cf 100644
--- a/federation/serverauth_test.go
+++ b/federation/serverauth_test.go
@@ -19,7 +19,7 @@ import (
func TestServerKeyResponse_VerifySelfSignature(t *testing.T) {
cli := federation.NewClient("", nil, nil)
ctx := context.Background()
- for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} {
+ for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} {
t.Run(name, func(t *testing.T) {
resp, err := cli.ServerKeys(ctx, name)
require.NoError(t, err)
diff --git a/filter.go b/filter.go
index c6c8211b..54973dab 100644
--- a/filter.go
+++ b/filter.go
@@ -57,7 +57,7 @@ type FilterPart struct {
// Validate checks if the filter contains valid property values
func (filter *Filter) Validate() error {
if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation {
- return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]")
+ return errors.New("bad event_format value")
}
return nil
}
diff --git a/format/htmlparser.go b/format/htmlparser.go
index e5f92896..e0507d93 100644
--- a/format/htmlparser.go
+++ b/format/htmlparser.go
@@ -93,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string
}
}
+func onlyBacktickCount(line string) (count int) {
+ for i := 0; i < len(line); i++ {
+ if line[i] != '`' {
+ return -1
+ }
+ count++
+ }
+ return
+}
+
+func DefaultMonospaceBlockConverter(code, language string, ctx Context) string {
+ if len(code) == 0 || code[len(code)-1] != '\n' {
+ code += "\n"
+ }
+ fence := "```"
+ for line := range strings.SplitSeq(code, "\n") {
+ count := onlyBacktickCount(strings.TrimSpace(line))
+ if count >= len(fence) {
+ fence = strings.Repeat("`", count+1)
+ }
+ }
+ return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence)
+}
+
// HTMLParser is a somewhat customizable Matrix HTML parser.
type HTMLParser struct {
PillConverter PillConverter
@@ -348,10 +372,7 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
if parser.MonospaceBlockConverter != nil {
return parser.MonospaceBlockConverter(preStr, language, ctx)
}
- if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' {
- preStr += "\n"
- }
- return fmt.Sprintf("```%s\n%s```", language, preStr)
+ return DefaultMonospaceBlockConverter(preStr, language, ctx)
default:
return parser.nodeToTagAwareString(node.FirstChild, ctx)
}
diff --git a/go.mod b/go.mod
index c2acc7d6..49a1d4e4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,42 +1,42 @@
module maunium.net/go/mautrix
-go 1.24.0
+go 1.25.0
-toolchain go1.25.4
+toolchain go1.26.0
require (
- filippo.io/edwards25519 v1.1.0
+ filippo.io/edwards25519 v1.2.0
github.com/chzyer/readline v1.5.1
github.com/coder/websocket v1.8.14
- github.com/lib/pq v1.10.9
- github.com/mattn/go-sqlite3 v1.14.32
+ github.com/lib/pq v1.11.2
+ github.com/mattn/go-sqlite3 v1.14.34
github.com/rs/xid v1.6.0
github.com/rs/zerolog v1.34.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
- github.com/yuin/goldmark v1.7.13
- go.mau.fi/util v0.9.3
+ github.com/yuin/goldmark v1.7.16
+ go.mau.fi/util v0.9.6
go.mau.fi/zeroconfig v0.2.0
- golang.org/x/crypto v0.44.0
- golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6
- golang.org/x/net v0.47.0
- golang.org/x/sync v0.18.0
+ golang.org/x/crypto v0.48.0
+ golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa
+ golang.org/x/net v0.50.0
+ golang.org/x/sync v0.19.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
)
require (
- github.com/coreos/go-systemd/v22 v22.5.0 // indirect
+ github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
- github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 // indirect
+ github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
- golang.org/x/sys v0.38.0 // indirect
- golang.org/x/text v0.31.0 // indirect
+ golang.org/x/sys v0.41.0 // indirect
+ golang.org/x/text v0.34.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)
diff --git a/go.sum b/go.sum
index b5fbf85f..871a5156 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
-filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
+filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
@@ -10,13 +10,14 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
-github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
+github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo=
+github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
+github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
@@ -24,10 +25,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
-github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 h1:QTvNkZ5ylY0PGgA+Lih+GdboMLY/G9SEGLMEGVjTVA4=
-github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
+github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
+github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -49,28 +50,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA=
-github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
-go.mau.fi/util v0.9.3 h1:aqNF8KDIN8bFpFbybSk+mEBil7IHeBwlujfyTnvP0uU=
-go.mau.fi/util v0.9.3/go.mod h1:krWWfBM1jWTb5f8NCa2TLqWMQuM81X7TGQjhMjBeXmQ=
+github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE=
+github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
+go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts=
+go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI=
go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU=
go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w=
-golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
-golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
-golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
-golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
-golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
-golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
-golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
-golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
+golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
+golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
+golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
+golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
+golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
+golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
-golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
-golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
+golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
+golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
+golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
diff --git a/id/contenturi.go b/id/contenturi.go
index e6a313f5..67127b6c 100644
--- a/id/contenturi.go
+++ b/id/contenturi.go
@@ -17,8 +17,14 @@ import (
)
var (
- InvalidContentURI = errors.New("invalid Matrix content URI")
- InputNotJSONString = errors.New("input doesn't look like a JSON string")
+ ErrInvalidContentURI = errors.New("invalid Matrix content URI")
+ ErrInputNotJSONString = errors.New("input doesn't look like a JSON string")
+)
+
+// Deprecated: use variables prefixed with Err
+var (
+ InvalidContentURI = ErrInvalidContentURI
+ InputNotJSONString = ErrInputNotJSONString
)
// ContentURIString is a string that's expected to be a Matrix content URI.
@@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !strings.HasPrefix(uri, "mxc://") {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = uri[6 : 6+index]
parsed.FileID = uri[6+index+1:]
@@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) {
if len(uri) == 0 {
return
} else if !bytes.HasPrefix(uri, mxcBytes) {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 {
- err = InvalidContentURI
+ err = ErrInvalidContentURI
} else {
parsed.Homeserver = string(uri[6 : 6+index])
parsed.FileID = string(uri[6+index+1:])
@@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) {
*uri = ContentURI{}
return nil
} else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
- return InputNotJSONString
+ return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString)
}
parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1])
if err != nil {
diff --git a/id/crypto.go b/id/crypto.go
index 355a84a8..ee857f78 100644
--- a/id/crypto.go
+++ b/id/crypto.go
@@ -53,6 +53,34 @@ const (
KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2"
)
+type KeySource string
+
+func (source KeySource) String() string {
+ return string(source)
+}
+
+func (source KeySource) Int() int {
+ switch source {
+ case KeySourceDirect:
+ return 100
+ case KeySourceBackup:
+ return 90
+ case KeySourceImport:
+ return 80
+ case KeySourceForward:
+ return 50
+ default:
+ return 0
+ }
+}
+
+const (
+ KeySourceDirect KeySource = "direct"
+ KeySourceBackup KeySource = "backup"
+ KeySourceImport KeySource = "import"
+ KeySourceForward KeySource = "forward"
+)
+
// BackupVersion is an arbitrary string that identifies a server side key backup.
type KeyBackupVersion string
diff --git a/id/matrixuri.go b/id/matrixuri.go
index 8f5ec849..d5c78bc7 100644
--- a/id/matrixuri.go
+++ b/id/matrixuri.go
@@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{
func (uri *MatrixURI) getQuery() url.Values {
q := make(url.Values)
- if uri.Via != nil && len(uri.Via) > 0 {
+ if len(uri.Via) > 0 {
q["via"] = uri.Via
}
if len(uri.Action) > 0 {
diff --git a/id/trust.go b/id/trust.go
index 04f6e36b..6255093e 100644
--- a/id/trust.go
+++ b/id/trust.go
@@ -16,6 +16,7 @@ type TrustState int
const (
TrustStateBlacklisted TrustState = -100
+ TrustStateDeviceKeyMismatch TrustState = -5
TrustStateUnset TrustState = 0
TrustStateUnknownDevice TrustState = 10
TrustStateForwarded TrustState = 20
@@ -23,7 +24,7 @@ const (
TrustStateCrossSignedTOFU TrustState = 100
TrustStateCrossSignedVerified TrustState = 200
TrustStateVerified TrustState = 300
- TrustStateInvalid TrustState = (1 << 31) - 1
+ TrustStateInvalid TrustState = -2147483647
)
func (ts *TrustState) UnmarshalText(data []byte) error {
@@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState {
switch strings.ToLower(val) {
case "blacklisted":
return TrustStateBlacklisted
+ case "device-key-mismatch":
+ return TrustStateDeviceKeyMismatch
case "unverified":
return TrustStateUnset
case "cross-signed-untrusted":
@@ -67,6 +70,8 @@ func (ts TrustState) String() string {
switch ts {
case TrustStateBlacklisted:
return "blacklisted"
+ case TrustStateDeviceKeyMismatch:
+ return "device-key-mismatch"
case TrustStateUnset:
return "unverified"
case TrustStateCrossSignedUntrusted:
diff --git a/id/userid.go b/id/userid.go
index 859d2358..726a0d58 100644
--- a/id/userid.go
+++ b/id/userid.go
@@ -219,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) {
for i := 0; i < len(strBytes); i++ {
b := strBytes[i]
if !isValidByte(b) {
- return "", fmt.Errorf("Byte pos %d: Invalid byte", i)
+ return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b)
}
if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _
if i+1 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after underscore at %d", i)
}
if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping
- return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i)
+ return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i)
}
if strBytes[i+1] == '_' {
outputBuffer.WriteByte('_')
@@ -237,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) {
i++ // skip next byte since we just handled it
} else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8
if i+2 >= len(strBytes) {
- return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i)
+ return "", fmt.Errorf("unexpected end of string after equals sign at %d", i)
}
dst := make([]byte, 1)
_, err := hex.Decode(dst, strBytes[i+1:i+3])
diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go
index 07e30810..4d2bc7cf 100644
--- a/mediaproxy/mediaproxy.go
+++ b/mediaproxy/mediaproxy.go
@@ -95,9 +95,13 @@ func (d *GetMediaResponseCallback) GetContentType() string {
return d.ContentType
}
+type FileMeta struct {
+ ContentType string
+ ReplacementFile string
+}
+
type GetMediaResponseFile struct {
- Callback func(w *os.File) error
- ContentType string
+ Callback func(w *os.File) (*FileMeta, error)
}
type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error)
@@ -139,6 +143,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx
}
mp.FederationRouter = http.NewServeMux()
mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation)
+ mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation)
mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion)
mp.ClientMediaRouter = http.NewServeMux()
mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia)
@@ -453,23 +458,35 @@ func doTempFileDownload(
if err != nil {
return false, fmt.Errorf("failed to create temp file: %w", err)
}
+ origTempFile := tempFile
defer func() {
- _ = tempFile.Close()
- _ = os.Remove(tempFile.Name())
+ _ = origTempFile.Close()
+ _ = os.Remove(origTempFile.Name())
}()
- err = data.Callback(tempFile)
+ meta, err := data.Callback(tempFile)
if err != nil {
return false, err
}
- _, err = tempFile.Seek(0, io.SeekStart)
- if err != nil {
- return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ if meta.ReplacementFile != "" {
+ tempFile, err = os.Open(meta.ReplacementFile)
+ if err != nil {
+ return false, fmt.Errorf("failed to open replacement file: %w", err)
+ }
+ defer func() {
+ _ = tempFile.Close()
+ _ = os.Remove(origTempFile.Name())
+ }()
+ } else {
+ _, err = tempFile.Seek(0, io.SeekStart)
+ if err != nil {
+ return false, fmt.Errorf("failed to seek to start of temp file: %w", err)
+ }
}
fileInfo, err := tempFile.Stat()
if err != nil {
return false, fmt.Errorf("failed to stat temp file: %w", err)
}
- mimeType := data.ContentType
+ mimeType := meta.ContentType
if mimeType == "" {
buf := make([]byte, 512)
n, err := tempFile.Read(buf)
diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go
index e52c387a..507c24a5 100644
--- a/mockserver/mockserver.go
+++ b/mockserver/mockserver.go
@@ -231,7 +231,7 @@ func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) {
}
func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) {
- var req mautrix.UploadCrossSigningKeysReq
+ var req mautrix.UploadCrossSigningKeysReq[any]
mustDecode(r, &req)
userID := ms.getUserID(r).UserID
diff --git a/pushrules/action.go b/pushrules/action.go
index 9838e88b..b5a884b2 100644
--- a/pushrules/action.go
+++ b/pushrules/action.go
@@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error {
if ok {
action.Action = ActionSetTweak
action.Tweak = PushActionTweak(tweak)
- action.Value, _ = val["value"]
+ action.Value = val["value"]
}
}
return nil
diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go
index 0d3eaf7a..37af3e34 100644
--- a/pushrules/condition_test.go
+++ b/pushrules/condition_test.go
@@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi
}
}
-func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition {
- return &pushrules.PushCondition{
- Kind: pushrules.KindEventPropertyContains,
- Key: key,
- Value: value,
- }
-}
-
func TestPushCondition_Match_InvalidKind(t *testing.T) {
condition := &pushrules.PushCondition{
Kind: pushrules.PushCondKind("invalid"),
diff --git a/requests.go b/requests.go
index f0287b3c..cc8b7266 100644
--- a/requests.go
+++ b/requests.go
@@ -66,14 +66,14 @@ const (
)
// ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
-type ReqRegister struct {
+type ReqRegister[UIAType any] struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"`
InhibitLogin bool `json:"inhibit_login,omitempty"`
RefreshToken bool `json:"refresh_token,omitempty"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
// Type for registration, only used for appservice user registrations
// https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions
@@ -320,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 {
return ""
}
-type UploadCrossSigningKeysReq struct {
+type UploadCrossSigningKeysReq[UIAType any] struct {
Master CrossSigningKeys `json:"master_key"`
SelfSigning CrossSigningKeys `json:"self_signing_key"`
UserSigning CrossSigningKeys `json:"user_signing_key"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type KeyMap map[id.DeviceKeyID]string
@@ -367,13 +367,12 @@ type ReqSendToDevice struct {
}
type ReqSendEvent struct {
- Timestamp int64
- TransactionID string
- UnstableDelay time.Duration
-
- DontEncrypt bool
-
- MeowEventID id.EventID
+ Timestamp int64
+ TransactionID string
+ UnstableDelay time.Duration
+ UnstableStickyDuration time.Duration
+ DontEncrypt bool
+ MeowEventID id.EventID
}
type ReqDelayedEvents struct {
@@ -393,14 +392,14 @@ type ReqDeviceInfo struct {
}
// ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid
-type ReqDeleteDevice struct {
- Auth interface{} `json:"auth,omitempty"`
+type ReqDeleteDevice[UIAType any] struct {
+ Auth UIAType `json:"auth,omitempty"`
}
// ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices
-type ReqDeleteDevices struct {
+type ReqDeleteDevices[UIAType any] struct {
Devices []id.DeviceID `json:"devices"`
- Auth interface{} `json:"auth,omitempty"`
+ Auth UIAType `json:"auth,omitempty"`
}
type ReqPutPushRule struct {
diff --git a/responses.go b/responses.go
index e7b6b75e..4fbe1fbc 100644
--- a/responses.go
+++ b/responses.go
@@ -258,15 +258,13 @@ func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) {
type RespMutualRooms struct {
Joined []id.RoomID `json:"joined"`
NextBatch string `json:"next_batch,omitempty"`
+ Count int `json:"count,omitempty"`
}
type RespRoomSummary struct {
PublicRoomInfo
- Membership event.Membership `json:"membership,omitempty"`
- RoomVersion id.RoomVersion `json:"room_version,omitempty"`
- Encryption id.Algorithm `json:"encryption,omitempty"`
- AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
+ Membership event.Membership `json:"membership,omitempty"`
UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"`
UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"`
@@ -344,6 +342,13 @@ type LazyLoadSummary struct {
InvitedMemberCount *int `json:"m.invited_member_count,omitempty"`
}
+func (lls *LazyLoadSummary) MemberCount() int {
+ if lls == nil {
+ return 0
+ }
+ return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount)
+}
+
func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool {
if lls == other {
return true
@@ -685,6 +690,10 @@ type PublicRoomInfo struct {
RoomType event.RoomType `json:"room_type"`
Topic string `json:"topic,omitempty"`
WorldReadable bool `json:"world_readable"`
+
+ RoomVersion id.RoomVersion `json:"room_version,omitempty"`
+ Encryption id.Algorithm `json:"encryption,omitempty"`
+ AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"`
}
// RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
diff --git a/room.go b/room.go
index c3ddb7e6..4292bff5 100644
--- a/room.go
+++ b/room.go
@@ -5,8 +5,6 @@ import (
"maunium.net/go/mautrix/id"
)
-type RoomStateMap = map[event.Type]map[string]*event.Event
-
// Room represents a single Matrix room.
type Room struct {
ID id.RoomID
@@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) {
// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event {
- stateEventMap, _ := room.State[eventType]
- evt, _ := stateEventMap[stateKey]
+ stateEventMap := room.State[eventType]
+ evt := stateEventMap[stateKey]
return evt
}
diff --git a/statestore.go b/statestore.go
index c6267c5b..2bd498dd 100644
--- a/statestore.go
+++ b/statestore.go
@@ -129,6 +129,7 @@ func NewMemoryStateStore() StateStore {
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
Create: make(map[id.RoomID]*event.Event),
+ JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent),
}
}
diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go
index c360acab..0925b748 100644
--- a/synapseadmin/roomapi.go
+++ b/synapseadmin/roomapi.go
@@ -75,8 +75,7 @@ type RespListRooms struct {
// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api
func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) {
var resp RespListRooms
- var reqURL string
- reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
+ reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery())
_, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp)
return resp, err
}
diff --git a/url.go b/url.go
index d888956a..91b3d49d 100644
--- a/url.go
+++ b/url.go
@@ -98,10 +98,8 @@ func (saup SynapseAdminURLPath) FullPath() []any {
// and appservice user ID set already.
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) {
- if urlQuery != nil {
- for k, v := range urlQuery {
- q.Set(k, v)
- }
+ for k, v := range urlQuery {
+ q.Set(k, v)
}
})
}
diff --git a/version.go b/version.go
index f6d20c3f..f00bbf39 100644
--- a/version.go
+++ b/version.go
@@ -8,7 +8,7 @@ import (
"strings"
)
-const Version = "v0.26.0"
+const Version = "v0.26.3"
var GoModVersion = ""
var Commit = ""
diff --git a/versions.go b/versions.go
index 2aaf6399..61b2e4ea 100644
--- a/versions.go
+++ b/versions.go
@@ -60,16 +60,18 @@ type UnstableFeature struct {
}
var (
- FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
- FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
- FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
- FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
- FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
- FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
- FeatureAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"}
- FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"}
- FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116}
- FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"}
+ FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17}
+ FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
+ FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111}
+ FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"}
+ FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"}
+ FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"}
+ FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"}
+ FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/}
+ FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"}
+ FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116}
+ FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"}
BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
@@ -79,6 +81,7 @@ var (
BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"}
BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"}
BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"}
+ BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"}
)
func (versions *RespVersions) Supports(feature UnstableFeature) bool {
@@ -123,6 +126,7 @@ var (
SpecV114 = MustParseSpecVersion("v1.14")
SpecV115 = MustParseSpecVersion("v1.15")
SpecV116 = MustParseSpecVersion("v1.16")
+ SpecV117 = MustParseSpecVersion("v1.17")
)
func (svf SpecVersionFormat) String() string {