Compare commits

..

No commits in common. "main" and "v0.25.0" have entirely different histories.

231 changed files with 2291 additions and 12158 deletions

View file

@ -10,12 +10,12 @@ jobs:
runs-on: ubuntu-latest
name: Lint (latest)
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v6
uses: actions/setup-go@v5
with:
go-version: "1.26"
go-version: "1.24"
cache: true
- name: Install libolm
@ -24,7 +24,6 @@ jobs:
- name: Install goimports
run: |
go install golang.org/x/tools/cmd/goimports@latest
go install honnef.co/go/tools/cmd/staticcheck@latest
export PATH="$HOME/go/bin:$PATH"
- name: Run pre-commit
@ -35,14 +34,14 @@ jobs:
strategy:
fail-fast: false
matrix:
go-version: ["1.25", "1.26"]
name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm)
go-version: ["1.24", "1.25"]
name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm)
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up Go ${{ matrix.go-version }}
uses: actions/setup-go@v6
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: true
@ -61,28 +60,28 @@ jobs:
- name: Test
run: go test -json -v ./... 2>&1 | gotestfmt
- name: Test (jsonv2)
env:
GOEXPERIMENT: jsonv2
run: go test -json -v ./... 2>&1 | gotestfmt
build-goolm:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
go-version: ["1.25", "1.26"]
name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm)
go-version: ["1.24", "1.25"]
name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm)
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up Go ${{ matrix.go-version }}
uses: actions/setup-go@v6
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: true
- name: Set up gotestfmt
uses: GoTestTools/gotestfmt-action@v2
with:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Build
run: |
rm -rf crypto/libolm

View file

@ -17,7 +17,7 @@ jobs:
lock-stale:
runs-on: ubuntu-latest
steps:
- uses: dessant/lock-threads@v6
- uses: dessant/lock-threads@v5
id: lock
with:
issue-inactive-days: 90

View file

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
@ -9,7 +9,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/tekwizely/pre-commit-golang
rev: v1.0.0-rc.4
rev: v1.0.0-rc.1
hooks:
- id: go-imports-repo
args:
@ -18,7 +18,8 @@ repos:
- "-w"
- id: go-vet-repo-mod
- id: go-mod-tidy
- id: go-staticcheck-repo-mod
# TODO enable this
#- id: go-staticcheck-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.4.2
@ -26,4 +27,3 @@ repos:
- id: prevent-literal-http-methods
- id: zerolog-ban-global-log
- id: zerolog-ban-msgf
- id: zerolog-use-stringer

View file

@ -1,202 +1,4 @@
## v0.26.3 (2026-02-16)
* Bumped minimum Go version to 1.25.
* *(client)* Added fields for sending [MSC4354] sticky events.
* *(bridgev2)* Added automatic message request accepting when sending message.
* *(mediaproxy)* Added support for federation thumbnail endpoint.
* *(crypto/ssss)* Improved support for recovery keys with slightly broken
metadata.
* *(crypto)* Changed key import to call session received callback even for
sessions that already exist in the database.
* *(appservice)* Fixed building websocket URL accidentally using file path
separators instead of always `/`.
* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field.
* *(client)* Fixed incorrect context usage in async uploads.
* *(crypto)* Fixed panic when passing invalid input to megolm message index
parser used for debugging.
* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned
up properly.
[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354
## v0.26.2 (2026-01-16)
* *(bridgev2)* Added chunked portal deletion to avoid database locks when
deleting large portals.
* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata
as per [MSC4392].
* *(bridgev2/login)* Added `default_value` for user input fields.
* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested
HTTP client settings and to reset active connections of the network connector.
* *(bridgev2)* Added interface to let network connectors get the provisioning
API HTTP router and add new endpoints.
* *(event)* Added blurhash field to Beeper link preview objects.
* *(event)* Added [MSC4391] support for bot commands.
* *(event)* Dropped [MSC4332] support for bot commands.
* *(client)* Changed media download methods to return an error if the provided
MXC URI is empty.
* *(client)* Stabilized support for [MSC4323].
* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events.
* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with
a portal re-ID call.
[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391
[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392
## v0.26.1 (2025-12-16)
* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return
the mime type from the callback rather than in the return get media return
value. The callback can now also redirect the caller to a different file.
* *(federation)* Added join/knock/leave functions
(thanks to [@nexy7574] in [#422]).
* *(federation/eventauth)* Fixed various incorrect checks.
* *(client)* Added backoff for retrying media uploads to external URLs
(with MSC3870).
* *(bridgev2/config)* Added support for overriding config fields using
environment variables.
* *(bridgev2/commands)* Added command to mute chat on remote network.
* *(bridgev2)* Added interface for network connectors to redirect to a different
user ID when handling an invite from Matrix.
* *(bridgev2)* Added interface for signaling message request status of portals.
* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag
is set in chat info.
* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if
bridging the new one is successful.
* *(bridgev2/mxmain)* Improved error message when trying to run bridge with
pre-megabridge database when no database migration exists.
* *(bridgev2)* Improved reliability of database migration when enabling split
portals.
* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats.
* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public
portal rooms.
* *(bridgev2)* Stopped hardcoding room versions in favor of checking
server capabilities to determine appropriate `/createRoom` parameters.
[#422]: https://github.com/mautrix/go/pull/422
## v0.26.0 (2025-11-16)
* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent`
has been able to do the same for a while now.
* *(client,federation)* Added size limits for responses to make it safer to send
requests to untrusted servers.
* *(client)* Added wrapper for `/admin/whois` client API
(thanks to [@nexy7574] in [#411]).
* *(synapseadmin)* Added `force_purge` option to DeleteRoom
(thanks to [@nexy7574] in [#420]).
* *(statestore)* Added saving join rules for rooms.
* *(bridgev2)* Added optional automatic rollback of room state if bridging the
change to the remote network fails.
* *(bridgev2)* Added management room notices if transient disconnect state
doesn't resolve within 3 minutes.
* *(bridgev2)* Added interface to signal that certain participants couldn't be
invited when creating a group.
* *(bridgev2)* Added `select` type for user input fields in login.
* *(bridgev2)* Added interface to let network connector customize personal
filtering space.
* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to
other bots.
* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever
possible.
* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names,
and encrypted files.
* *(bridgev2/commands)* Added command to resync a single portal.
* *(bridgev2/commands)* Added create group command.
* *(bridgev2/config)* Added option to limit maximum number of logins.
* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room
is public.
* *(bridgev2/disappear)* Changed read receipt handling to only start
disappearing timers for messages up to the read message (note: may not work in
all cases if the read receipt points at an unknown event).
* *(event/reply)* Changed plaintext reply fallback removal to only happen when
an HTML reply fallback is removed successfully.
* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run.
* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages.
* *(federation)* Fixed HTTP method for sending transactions
(thanks to [@nexy7574] in [#426]).
* *(federation)* Fixed response body being closed even when using `DontReadBody`
parameter.
* *(federation)* Fixed validating auth for requests with query params.
* *(federation/eventauth)* Fixed typo causing restricted joins to not work.
[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169
[#411]: github.com/mautrix/go/pull/411
[#420]: github.com/mautrix/go/pull/420
[#426]: github.com/mautrix/go/pull/426
## v0.25.2 (2025-10-16)
* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into
`ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old
behavior, but most users likely want the relaxed version, as there are real
users whose user IDs aren't valid under the strict rules.
* *(crypto)* Added helper methods for generating and verifying with recovery
keys.
* *(bridgev2/matrix)* Added config option to automatically generate a recovery
key for the bridge bot and self-sign the bridge's device.
* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode
for encryption with standard servers like Synapse.
* *(bridgev2)* Added optional support for implicit read receipts.
* *(bridgev2)* Added interface for deleting chats on remote network.
* *(bridgev2)* Added local enforcement of media duration and size limits.
* *(bridgev2)* Extended event duration logging to log any event taking too long.
* *(bridgev2)* Improved validation in group creation provisioning API.
* *(event)* Added event type constant for poll end events.
* *(client)* Added wrapper for searching user directory.
* *(client)* Improved support for managing [MSC4140] delayed events.
* *(crypto/helper)* Changed default sync handling to not block on waiting for
decryption keys. On initial sync, keys won't be requested at all by default.
* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1).
* *(bridgev2)* Fixed various bugs with migrating to split portals.
* *(event)* Fixed poll start events having incorrect null `m.relates_to`.
* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling.
* *(federation)* Fixed various bugs in event auth.
## v0.25.1 (2025-09-16)
* *(client)* Fixed HTTP method of delete devices API call
(thanks to [@fmseals] in [#393]).
* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints
(thanks to [@nexy7574] in [#407]).
* *(client)* Stabilized support for extensible profiles.
* *(client)* Stabilized support for `state_after` in sync.
* *(client)* Removed deprecated MSC2716 requests.
* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if
the content struct doesn't implement `Relatable`.
* *(crypto)* Changed olm unwedging to ignore newly created sessions if they
haven't been used successfully in either direction.
* *(federation)* Added utilities for generating, parsing, validating and
authorizing PDUs.
* Note: the new PDU code depends on `GOEXPERIMENT=jsonv2`
* *(event)* Added `is_animated` flag from [MSC4230] to file info.
* *(event)* Added types for [MSC4332]: In-room bot commands.
* *(event)* Added missing poll end event type for [MSC3381].
* *(appservice)* Fixed URLs not being escaped properly when using unix socket
for homeserver connections.
* *(format)* Added more helpers for forming markdown links.
* *(event,bridgev2)* Added support for Beeper's disappearing message state event.
* *(bridgev2)* Redesigned group creation interface and added support in commands
and provisioning API.
* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to
get an old event. The method is best effort only, as some configurations don't
allow fetching old events.
* *(bridgev2)* Added shared logic for provisioning that can be reused by the
API, commands and other sources.
* *(bridgev2)* Fixed mentions and URL previews not being copied over when
caption and media are merged.
* *(bridgev2)* Removed config option to change provisioning API prefix, which
had already broken in the previous release.
[@fmseals]: https://github.com/fmseals
[#393]: https://github.com/mautrix/go/pull/393
[#407]: https://github.com/mautrix/go/pull/407
[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381
[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230
[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323
[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332
## v0.25.0 (2025-08-16)
## v0.25.0 (unreleased)
* Bumped minimum Go version to 1.24.
* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux
@ -437,7 +239,6 @@
[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156
[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190
[#288]: https://github.com/mautrix/go/pull/288
[@onestacked]: https://github.com/onestacked
## v0.22.0 (2024-11-16)

View file

@ -1,9 +1,8 @@
# mautrix-go
[![GoDoc](https://pkg.go.dev/badge/maunium.net/go/mautrix)](https://pkg.go.dev/maunium.net/go/mautrix)
A Golang Matrix framework. Used by [gomuks](https://gomuks.app),
[go-neb](https://github.com/matrix-org/go-neb),
[mautrix-whatsapp](https://github.com/mautrix/whatsapp)
A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks),
[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp)
and others.
Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net)
@ -14,10 +13,9 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or
In addition to the basic client API features the original project has, this framework also has:
* Appservice support (Intent API like mautrix-python, room state storage, etc)
* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc)
* End-to-end encryption support (incl. interactive SAS verification)
* High-level module for building puppeting bridges
* Partial federation module (making requests, PDU processing and event authorization)
* A media proxy server which can be used to expose anything as a Matrix media repo
* High-level module for building chat clients
* Wrapper functions for the Synapse admin API
* Structs for parsing event content
* Helpers for parsing and generating Matrix HTML

View file

@ -334,7 +334,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error {
} else if as.hsURLForClient.Scheme == "" {
as.hsURLForClient.Scheme = "https"
}
as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath()
as.hsURLForClient.RawPath = parsedURL.EscapedPath()
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar}
@ -360,7 +360,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
AccessToken: as.Registration.AppToken,
UserAgent: as.UserAgent,
StateStore: as.StateStore,
Log: as.Log.With().Stringer("as_user_id", userID).Logger(),
Log: as.Log.With().Str("as_user_id", userID.String()).Logger(),
Client: as.HTTPClient,
DefaultHTTPRetries: as.DefaultHTTPRetries,
SpecVersions: as.SpecVersions,

View file

@ -201,7 +201,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def
}
err := evt.Content.ParseRaw(evt.Type)
if errors.Is(err, event.ErrUnsupportedContentType) {
log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event")
log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event")
} else if err != nil {
log.Warn().Err(err).
Str("event_id", evt.ID.String()).

View file

@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
}
func (intent *IntentAPI) Register(ctx context.Context) error {
_, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{
_, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{
Username: intent.Localpart,
Type: mautrix.AuthTypeAppservice,
InhibitLogin: true,
@ -214,31 +214,23 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any {
}
}
func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...)
return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON)
}
func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
}
if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...)
}
// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead
func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
}
func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
if eventType != event.StateMember || stateKey != string(intent.UserID) {
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
@ -247,12 +239,15 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...)
return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON)
}
// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead
func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
if err := intent.EnsureJoined(ctx, roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts)
}
func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
@ -311,7 +306,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...)
return &mautrix.RespJoinRoom{RoomID: roomID}, err
return &mautrix.RespJoinRoom{}, err
}
return intent.Client.JoinRoomByID(ctx, roomID)
}

View file

@ -11,10 +11,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"sync"
"sync/atomic"
@ -56,7 +55,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
var prefixMessage string
for unwrappedErr != nil {
errorData, jsonErr = json.Marshal(unwrappedErr)
if len(errorData) > 2 && jsonErr == nil {
if errorData != nil && len(errorData) > 2 && jsonErr == nil {
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
prefixMessage = strings.TrimRight(prefixMessage, ": ")
break
@ -293,16 +292,10 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error)
as.Log.Debug().Msg("Ignoring non-text message from websocket")
continue
}
data, err := io.ReadAll(reader)
if err != nil {
as.Log.Debug().Err(err).Msg("Error reading data from websocket")
stopFunc(parseCloseError(err))
return
}
var msg WebsocketMessage
err = json.Unmarshal(data, &msg)
err = json.NewDecoder(reader).Decode(&msg)
if err != nil {
as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket")
as.Log.Debug().Err(err).Msg("Error reading JSON from websocket")
stopFunc(parseCloseError(err))
return
}
@ -374,7 +367,7 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
copiedURL := *as.hsURLForClient
parsed = &copiedURL
}
parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
if parsed.Scheme == "http" {
parsed.Scheme = "ws"
} else if parsed.Scheme == "https" {
@ -419,7 +412,6 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
}
})
}
ws.SetReadLimit(50 * 1024 * 1024)
as.ws = ws
as.StopWebsocket = stopFunc
as.PrepareWebsocket()

View file

@ -9,14 +9,11 @@ package bridgev2
import (
"context"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exhttp"
"go.mau.fi/util/exsync"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
@ -54,7 +51,6 @@ type Bridge struct {
Background bool
ExternallyManagedDB bool
stopping atomic.Bool
wakeupBackfillQueue chan struct{}
stopBackfillQueue *exsync.Event
@ -130,7 +126,6 @@ func (br *Bridge) Start(ctx context.Context) error {
func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error {
br.Background = true
br.stopping.Store(false)
err := br.StartConnectors(ctx)
if err != nil {
return err
@ -166,7 +161,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
case <-time.After(20 * time.Second):
case <-ctx.Done():
}
br.stopping.Store(true)
return nil
} else {
br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode")
@ -176,7 +170,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
func (br *Bridge) StartConnectors(ctx context.Context) error {
br.Log.Info().Msg("Starting bridge")
br.stopping.Store(false)
if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil {
br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background())
br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx)
@ -189,11 +182,7 @@ func (br *Bridge) StartConnectors(ctx context.Context) error {
}
}
if !br.Background {
var postMigrate func()
br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx)
if postMigrate != nil {
defer postMigrate()
}
br.didSplitPortals = br.MigrateToSplitPortals(ctx)
}
br.Log.Info().Msg("Starting Matrix connector")
err := br.Matrix.Start(ctx)
@ -282,64 +271,20 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps b
Msg("Resent bridge info to all portals")
}
func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) {
func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool {
log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger()
ctx = log.WithContext(ctx)
if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" {
return false, nil
return false
}
affected, err := br.DB.Portal.MigrateToSplitPortals(ctx)
if err != nil {
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals")
os.Exit(31)
return false, nil
log.Err(err).Msg("Failed to migrate portals")
return false
}
log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals")
affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx)
if err != nil {
log.Err(err).Msg("Failed to fix parent portals after split portal migration")
os.Exit(31)
return false, nil
}
log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration")
withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx)
if err != nil {
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate")
os.Exit(31)
return false, nil
}
var roomsToDelete []id.RoomID
log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver")
for _, portal := range withoutReceiver {
if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil {
log.Err(err).
Str("portal_id", string(portal.ID)).
Stringer("mxid", portal.MXID).
Msg("Failed to delete portal database row that failed to migrate")
} else if portal.MXID != "" {
log.Debug().
Str("portal_id", string(portal.ID)).
Stringer("mxid", portal.MXID).
Msg("Marked portal room for deletion from homeserver")
roomsToDelete = append(roomsToDelete, portal.MXID)
} else {
log.Debug().
Str("portal_id", string(portal.ID)).
Msg("Deleted portal row with no Matrix room")
}
}
br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true")
log.Info().Msg("Finished split portal migration successfully")
return affected > 0, func() {
for _, roomID := range roomsToDelete {
if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil {
log.Err(err).
Stringer("mxid", roomID).
Msg("Failed to delete portal room that failed to migrate")
}
}
log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate")
}
return affected > 0
}
func (br *Bridge) StartLogins(ctx context.Context) error {
@ -374,46 +319,6 @@ func (br *Bridge) StartLogins(ctx context.Context) error {
return nil
}
func (br *Bridge) ResetNetworkConnections() {
nrn, ok := br.Network.(NetworkResettingNetwork)
if ok {
br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections")
nrn.ResetNetworkConnections()
return
}
br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually")
for _, login := range br.GetAllCachedUserLogins() {
login.Log.Debug().Msg("Disconnecting and recreating client for network reset")
ctx := login.Log.WithContext(br.BackgroundCtx)
login.Client.Disconnect()
err := login.recreateClient(ctx)
if err != nil {
login.Log.Err(err).Msg("Failed to recreate client during network reset")
login.BridgeState.Send(status.BridgeState{
StateEvent: status.StateUnknownError,
Error: "bridgev2-network-reset-fail",
Info: map[string]any{"go_error": err.Error()},
})
} else {
login.Client.Connect(ctx)
}
}
br.Log.Info().Msg("Finished resetting all user logins")
}
func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings {
mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings)
if ok {
return mchs.GetHTTPClientSettings()
}
return exhttp.SensibleClientSettings
}
func (br *Bridge) IsStopping() bool {
return br.stopping.Load()
}
func (br *Bridge) Stop() {
br.stop(false, 0)
}
@ -424,7 +329,6 @@ func (br *Bridge) StopWithTimeout(timeout time.Duration) {
func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) {
br.Log.Info().Msg("Shutting down bridge")
br.stopping.Store(true)
br.DisappearLoop.Stop()
br.stopBackfillQueue.Set()
br.Matrix.PreStop()

View file

@ -34,12 +34,10 @@ type BackfillQueueConfig struct {
MaxBatchesOverride map[string]int `yaml:"max_batches_override"`
}
func (bqc *BackfillQueueConfig) GetOverride(names ...string) int {
for _, name := range names {
override, ok := bqc.MaxBatchesOverride[name]
if ok {
return override
}
func (bqc *BackfillQueueConfig) GetOverride(name string) int {
override, ok := bqc.MaxBatchesOverride[name]
if !ok {
return bqc.MaxBatches
}
return bqc.MaxBatches
return override
}

View file

@ -33,8 +33,6 @@ type Config struct {
Encryption EncryptionConfig `yaml:"encryption"`
Logging zeroconfig.Config `yaml:"logging"`
EnvConfigPrefix string `yaml:"env_config_prefix"`
ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"`
}
@ -62,40 +60,36 @@ type CleanupOnLogouts struct {
}
type BridgeConfig struct {
CommandPrefix string `yaml:"command_prefix"`
PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
AsyncEvents bool `yaml:"async_events"`
SplitPortals bool `yaml:"split_portals"`
ResendBridgeInfo bool `yaml:"resend_bridge_info"`
NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
BridgeStatusNotices string `yaml:"bridge_status_notices"`
UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"`
BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
BridgeNotices bool `yaml:"bridge_notices"`
TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
CrossRoomReplies bool `yaml:"cross_room_replies"`
OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"`
KickMatrixUsers bool `yaml:"kick_matrix_users"`
CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
Relay RelayConfig `yaml:"relay"`
Permissions PermissionConfig `yaml:"permissions"`
Backfill BackfillConfig `yaml:"backfill"`
CommandPrefix string `yaml:"command_prefix"`
PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
AsyncEvents bool `yaml:"async_events"`
SplitPortals bool `yaml:"split_portals"`
ResendBridgeInfo bool `yaml:"resend_bridge_info"`
NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
BridgeStatusNotices string `yaml:"bridge_status_notices"`
UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
BridgeNotices bool `yaml:"bridge_notices"`
TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
CrossRoomReplies bool `yaml:"cross_room_replies"`
OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
Relay RelayConfig `yaml:"relay"`
Permissions PermissionConfig `yaml:"permissions"`
Backfill BackfillConfig `yaml:"backfill"`
}
type MatrixConfig struct {
MessageStatusEvents bool `yaml:"message_status_events"`
DeliveryReceipts bool `yaml:"delivery_receipts"`
MessageErrorNotices bool `yaml:"message_error_notices"`
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
FederateRooms bool `yaml:"federate_rooms"`
UploadFileThreshold int64 `yaml:"upload_file_threshold"`
GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"`
MessageStatusEvents bool `yaml:"message_status_events"`
DeliveryReceipts bool `yaml:"delivery_receipts"`
MessageErrorNotices bool `yaml:"message_error_notices"`
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
FederateRooms bool `yaml:"federate_rooms"`
UploadFileThreshold int64 `yaml:"upload_file_threshold"`
}
type AnalyticsConfig struct {
@ -105,6 +99,7 @@ type AnalyticsConfig struct {
}
type ProvisioningConfig struct {
Prefix string `yaml:"prefix"`
SharedSecret string `yaml:"shared_secret"`
DebugEndpoints bool `yaml:"debug_endpoints"`
EnableSessionTransfers bool `yaml:"enable_session_transfers"`
@ -117,12 +112,10 @@ type DirectMediaConfig struct {
}
type PublicMediaConfig struct {
Enabled bool `yaml:"enabled"`
SigningKey string `yaml:"signing_key"`
Expiry int `yaml:"expiry"`
HashLength int `yaml:"hash_length"`
PathPrefix string `yaml:"path_prefix"`
UseDatabase bool `yaml:"use_database"`
Enabled bool `yaml:"enabled"`
SigningKey string `yaml:"signing_key"`
HashLength int `yaml:"hash_length"`
Expiry int `yaml:"expiry"`
}
type DoublePuppetConfig struct {

View file

@ -16,8 +16,6 @@ type EncryptionConfig struct {
Require bool `yaml:"require"`
Appservice bool `yaml:"appservice"`
MSC4190 bool `yaml:"msc4190"`
MSC4392 bool `yaml:"msc4392"`
SelfSign bool `yaml:"self_sign"`
PlaintextMentions bool `yaml:"plaintext_mentions"`

View file

@ -133,7 +133,9 @@ func doMigrateLegacy(helper up.Helper, python bool) {
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"})
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"})
CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"})

View file

@ -24,7 +24,6 @@ type Permissions struct {
DoublePuppet bool `yaml:"double_puppet"`
Admin bool `yaml:"admin"`
ManageRelay bool `yaml:"manage_relay"`
MaxLogins int `yaml:"max_logins"`
}
type PermissionConfig map[string]*Permissions
@ -41,7 +40,10 @@ func (pc PermissionConfig) IsConfigured() bool {
_, hasExampleDomain := pc["example.com"]
_, hasExampleUser := pc["@admin:example.com"]
exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain)
return len(pc) > exampleLen
if len(pc) <= exampleLen {
return false
}
return true
}
func (pc PermissionConfig) Get(userID id.UserID) Permissions {

View file

@ -33,7 +33,6 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key")
helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices")
helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect")
helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects")
helper.Copy(up.Bool, "bridge", "bridge_matrix_leave")
helper.Copy(up.Bool, "bridge", "bridge_notices")
helper.Copy(up.Bool, "bridge", "tag_only_on_create")
@ -41,8 +40,6 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "bridge", "mute_only_on_create")
helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages")
helper.Copy(up.Bool, "bridge", "cross_room_replies")
helper.Copy(up.Bool, "bridge", "revert_failed_state_changes")
helper.Copy(up.Bool, "bridge", "kick_matrix_users")
helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled")
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private")
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed")
@ -101,12 +98,12 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Bool, "matrix", "sync_direct_chat_list")
helper.Copy(up.Bool, "matrix", "federate_rooms")
helper.Copy(up.Int, "matrix", "upload_file_threshold")
helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info")
helper.Copy(up.Str|up.Null, "analytics", "token")
helper.Copy(up.Str|up.Null, "analytics", "url")
helper.Copy(up.Str|up.Null, "analytics", "user_id")
helper.Copy(up.Str, "provisioning", "prefix")
if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" {
sharedSecret := random.String(64)
helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret")
@ -136,8 +133,6 @@ func doUpgrade(helper up.Helper) {
}
helper.Copy(up.Int, "public_media", "expiry")
helper.Copy(up.Int, "public_media", "hash_length")
helper.Copy(up.Str|up.Null, "public_media", "path_prefix")
helper.Copy(up.Bool, "public_media", "use_database")
helper.Copy(up.Bool, "backfill", "enabled")
helper.Copy(up.Int, "backfill", "max_initial_messages")
@ -163,8 +158,6 @@ func doUpgrade(helper up.Helper) {
} else {
helper.Copy(up.Bool, "encryption", "msc4190")
}
helper.Copy(up.Bool, "encryption", "msc4392")
helper.Copy(up.Bool, "encryption", "self_sign")
helper.Copy(up.Bool, "encryption", "allow_key_sharing")
if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" {
helper.Set(up.Str, random.String(64), "encryption", "pickle_key")
@ -187,8 +180,6 @@ func doUpgrade(helper up.Helper) {
helper.Copy(up.Int, "encryption", "rotation", "messages")
helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation")
helper.Copy(up.Str|up.Null, "env_config_prefix")
helper.Copy(up.Map, "logging")
}
@ -216,7 +207,6 @@ var SpacedBlocks = [][]string{
{"backfill"},
{"double_puppet"},
{"encryption"},
{"env_config_prefix"},
{"logging"},
}

View file

@ -15,15 +15,12 @@ import (
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/exfmt"
"maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
)
var CatchBridgeStateQueuePanics = true
type BridgeStateQueue struct {
prevUnsent *status.BridgeState
prevSent *status.BridgeState
@ -32,13 +29,8 @@ type BridgeStateQueue struct {
bridge *Bridge
login *UserLogin
firstTransientDisconnect time.Time
cancelScheduledNotice atomic.Pointer[context.CancelFunc]
stopChan chan struct{}
stopReconnect atomic.Pointer[context.CancelFunc]
unknownErrorReconnects int
}
func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) {
@ -82,63 +74,31 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() {
if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil {
(*cancelFn)()
}
if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil {
(*cancelFn)()
}
}
func (bsq *BridgeStateQueue) loop() {
if CatchBridgeStateQueuePanics {
defer func() {
err := recover()
if err != nil {
bsq.login.Log.Error().
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
Any(zerolog.ErrorFieldName, err).
Msg("Panic in bridge state loop")
}
}()
}
defer func() {
err := recover()
if err != nil {
bsq.login.Log.Error().
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
Any(zerolog.ErrorFieldName, err).
Msg("Panic in bridge state loop")
}
}()
for state := range bsq.ch {
bsq.immediateSendBridgeState(state)
}
}
func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) {
log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger()
ctx := log.WithContext(bsq.bridge.BackgroundCtx)
if !bsq.waitForTransientDisconnectReconnect(ctx) {
return
}
prevUnsent := bsq.GetPrevUnsent()
prev := bsq.GetPrev()
if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent ||
prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect {
log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice")
return
}
log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice")
bsq.sendNotice(ctx, triggeredBy, true)
}
func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) {
func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) {
noticeConfig := bsq.bridge.Config.BridgeStatusNotices
isError := state.StateEvent == status.StateBadCredentials ||
state.StateEvent == status.StateUnknownError ||
state.UserAction == status.UserActionOpenNative ||
(isDelayed && state.StateEvent == status.StateTransientDisconnect)
state.UserAction == status.UserActionOpenNative
sendNotice := noticeConfig == "all" || (noticeConfig == "errors" &&
(isError || (bsq.errorSent && state.StateEvent == status.StateConnected)))
if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError {
bsq.firstTransientDisconnect = time.Time{}
}
if !sendNotice {
if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect {
if bsq.firstTransientDisconnect.IsZero() {
bsq.firstTransientDisconnect = time.Now()
}
go bsq.scheduleNotice(state)
}
return
}
managementRoom, err := bsq.login.User.GetManagementRoom(ctx)
@ -154,9 +114,6 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge
if state.Error != "" {
message += fmt.Sprintf(" (`%s`)", state.Error)
}
if isDelayed {
message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay))
}
if state.Message != "" {
message += fmt.Sprintf(": %s", state.Message)
}
@ -194,14 +151,8 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat
} else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError {
log.Debug().Msg("Not reconnecting as the previous state was not an unknown error")
return
} else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects {
log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached")
return
}
bsq.unknownErrorReconnects++
log.Info().
Int("reconnect_num", bsq.unknownErrorReconnects).
Msg("Disconnecting and reconnecting login due to unknown error")
log.Info().Msg("Disconnecting and reconnecting login due to unknown error")
bsq.login.Disconnect()
log.Debug().Msg("Disconnection finished, recreating client and reconnecting")
err := bsq.login.recreateClient(ctx)
@ -220,30 +171,14 @@ func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) b
return false
}
reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2))
return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect)
}
const TransientDisconnectNoticeDelay = 3 * time.Minute
func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool {
timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay))
zerolog.Ctx(ctx).Trace().
Stringer("duration", timeUntilSchedule).
Msg("Waiting before sending notice about transient disconnect")
return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice)
}
func (bsq *BridgeStateQueue) waitForReconnect(
ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc],
) bool {
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
if oldCancel := ptr.Swap(&cancel); oldCancel != nil {
if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil {
(*oldCancel)()
}
select {
case <-time.After(reconnectIn):
return ptr.CompareAndSwap(&cancel, nil)
return bsq.stopReconnect.CompareAndSwap(&cancel, nil)
case <-cancelCtx.Done():
return false
case <-bsq.stopChan:
@ -263,7 +198,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState)
}
ctx := bsq.login.Log.WithContext(context.Background())
bsq.sendNotice(ctx, state, false)
bsq.sendNotice(ctx, state)
retryIn := 2
for {

View file

@ -7,13 +7,10 @@
package commands
import (
"encoding/json"
"strings"
"time"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
)
var CommandRegisterPush = &FullHandler{
@ -62,64 +59,3 @@ var CommandRegisterPush = &FullHandler{
RequiresLogin: true,
NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI],
}
var CommandSendAccountData = &FullHandler{
Func: func(ce *Event) {
if len(ce.Args) < 2 {
ce.Reply("Usage: `$cmdprefix debug-account-data <type> <content>")
return
}
var content event.Content
evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType}
ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0]))
err := json.Unmarshal([]byte(ce.RawArgs), &content)
if err != nil {
ce.Reply("Failed to parse JSON: %v", err)
return
}
err = content.ParseRaw(evtType)
if err != nil {
ce.Reply("Failed to deserialize content: %v", err)
return
}
res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{
Sender: ce.User.MXID,
Type: evtType,
Timestamp: time.Now().UnixMilli(),
RoomID: ce.RoomID,
Content: content,
})
ce.Reply("Result: %+v", res)
},
Name: "debug-account-data",
Help: HelpMeta{
Section: HelpSectionAdmin,
Description: "Send a room account data event to the bridge",
Args: "<_type_> <_content_>",
},
RequiresAdmin: true,
RequiresPortal: true,
RequiresLogin: true,
}
var CommandResetNetwork = &FullHandler{
Func: func(ce *Event) {
if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") {
nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork)
if ok {
nrn.ResetHTTPTransport()
} else {
ce.Reply("Network connector does not support resetting HTTP transport")
}
}
ce.Bridge.ResetNetworkConnections()
ce.React("✅️")
},
Name: "debug-reset-network",
Help: HelpMeta{
Section: HelpSectionAdmin,
Description: "Reset network connections to the remote network",
Args: "[--reset-transport]",
},
RequiresAdmin: true,
}

View file

@ -70,15 +70,6 @@ func fnLogin(ce *Event) {
}
ce.Args = ce.Args[1:]
}
if reauth == nil && ce.User.HasTooManyLogins() {
ce.Reply(
"You have reached the maximum number of logins (%d). "+
"Please logout from an existing login before creating a new one. "+
"If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.",
ce.User.Permissions.MaxLogins,
)
return
}
flows := ce.Bridge.Network.GetLoginFlows()
var chosenFlowID string
if len(ce.Args) > 0 {
@ -121,7 +112,6 @@ func fnLogin(ce *Event) {
ce.Reply("Failed to start login: %v", err)
return
}
ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process")
nextStep = checkLoginCommandDirectParams(ce, login, nextStep)
if nextStep != nil {
@ -200,14 +190,11 @@ type userInputLoginCommandState struct {
func (uilcs *userInputLoginCommandState) promptNext(ce *Event) {
field := uilcs.RemainingFields[0]
parts := []string{fmt.Sprintf("Please enter your %s", field.Name)}
if field.Description != "" {
parts = append(parts, field.Description)
ce.Reply("Please enter your %s\n%s", field.Name, field.Description)
} else {
ce.Reply("Please enter your %s", field.Name)
}
if len(field.Options) > 0 {
parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `")))
}
ce.Reply(strings.Join(parts, "\n"))
StoreCommandState(ce.User, &CommandState{
Next: MinimalCommandHandlerFunc(uilcs.submitNext),
Action: "Login",
@ -252,19 +239,14 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return fmt.Errorf("failed to upload image: %w", err)
}
content := &event.MessageEventContent{
MsgType: event.MsgImage,
FileName: "qr.png",
URL: qrMXC,
File: qrFile,
MsgType: event.MsgImage,
FileName: "qr.png",
URL: qrMXC,
File: qrFile,
Body: qr,
Format: event.FormatHTML,
FormattedBody: fmt.Sprintf("<pre><code>%s</code></pre>", html.EscapeString(qr)),
Info: &event.FileInfo{
MimeType: "image/png",
Width: qrSizePx,
Height: qrSizePx,
Size: len(qrData),
},
}
if *prevEventID != "" {
content.SetEdit(*prevEventID)
@ -279,36 +261,6 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
return nil
}
func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
for _, att := range atts {
if att.FileName == "" {
return fmt.Errorf("missing attachment filename")
}
mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
if err != nil {
return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
}
content := &event.MessageEventContent{
MsgType: att.Type,
FileName: att.FileName,
URL: mxc,
File: file,
Info: &event.FileInfo{
MimeType: att.Info.MimeType,
Width: att.Info.Width,
Height: att.Info.Height,
Size: att.Info.Size,
},
Body: att.FileName,
}
_, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
if err != nil {
return nil
}
}
return nil
}
type contextKey int
const (
@ -500,7 +452,6 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
}
func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
if step.Instructions != "" {
ce.Reply(step.Instructions)
}
@ -515,10 +466,6 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte
Override: override,
}).prompt(ce)
case bridgev2.LoginStepTypeUserInput:
err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
if err != nil {
ce.Reply("Failed to send attachments: %v", err)
}
(&userInputLoginCommandState{
Login: login.(bridgev2.LoginProcessUserInput),
RemainingFields: step.UserInputParams.Fields,

View file

@ -41,11 +41,10 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor {
}
proc.AddHandlers(
CommandHelp, CommandCancel,
CommandRegisterPush, CommandSendAccountData, CommandResetNetwork,
CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
CommandSetRelay, CommandUnsetRelay,
CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute,
CommandResolveIdentifier, CommandStartChat, CommandSearch,
CommandSudo, CommandDoIn,
)
return proc

View file

@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) {
}
onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly
var relay *bridgev2.UserLogin
if len(ce.Args) == 0 && ce.Portal.Receiver == "" {
if len(ce.Args) == 0 {
relay = ce.User.GetDefaultLogin()
isLoggedIn := relay != nil
if onlySetDefaultRelays {
@ -73,19 +73,9 @@ func fnSetRelay(ce *Event) {
}
}
} else {
var targetID networkid.UserLoginID
if ce.Portal.Receiver != "" {
targetID = ce.Portal.Receiver
if len(ce.Args) > 0 && ce.Args[0] != string(targetID) {
ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID)
return
}
} else {
targetID = networkid.UserLoginID(ce.Args[0])
}
relay = ce.Bridge.GetCachedUserLoginByID(targetID)
relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
if relay == nil {
ce.Reply("User login with ID `%s` not found", targetID)
ce.Reply("User login with ID `%s` not found", ce.Args[0])
return
} else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) {
// All good

View file

@ -1,4 +1,4 @@
// Copyright (c) 2025 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -8,21 +8,13 @@ package commands
import (
"context"
"errors"
"fmt"
"html"
"maps"
"slices"
"strings"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/bridgev2/provisionutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
@ -38,35 +30,6 @@ var CommandResolveIdentifier = &FullHandler{
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
var CommandSyncChat = &FullHandler{
Func: func(ce *Event) {
login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false)
if err != nil {
ce.Log.Err(err).Msg("Failed to find login for sync")
ce.Reply("Failed to find login: %v", err)
return
} else if login == nil {
ce.Reply("No login found for sync")
return
}
info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal)
if err != nil {
ce.Log.Err(err).Msg("Failed to get chat info for sync")
ce.Reply("Failed to get chat info: %v", err)
return
}
ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{})
ce.React("✅️")
},
Name: "sync-portal",
Help: HelpMeta{
Section: HelpSectionChats,
Description: "Sync the current portal room",
},
RequiresPortal: true,
RequiresLogin: true,
}
var CommandStartChat = &FullHandler{
Func: fnResolveIdentifier,
Name: "start-chat",
@ -80,15 +43,9 @@ var CommandStartChat = &FullHandler{
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
}
func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
var remainingArgs []string
if len(ce.Args) > 1 {
remainingArgs = ce.Args[1:]
}
var login *bridgev2.UserLogin
if len(ce.Args) > 0 {
login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
}
func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
remainingArgs := ce.Args[1:]
login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
if login == nil || login.UserMXID != ce.User.MXID {
remainingArgs = ce.Args
login = ce.User.GetDefaultLogin()
@ -100,13 +57,24 @@ func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*
return login, api, remainingArgs
}
func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string {
if resp.MXID != "" {
return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL())
} else if resp.Name != "" {
return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name)
func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string {
var targetName string
var targetMXID id.UserID
if resp.Ghost != nil {
if resp.UserInfo != nil {
resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
}
targetName = resp.Ghost.Name
targetMXID = resp.Ghost.Intent.GetMXID()
} else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
targetName = *resp.UserInfo.Name
}
if targetMXID != "" {
return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL())
} else if targetName != "" {
return fmt.Sprintf("`%s` / %s", resp.UserID, targetName)
} else {
return fmt.Sprintf("`%s`", resp.ID)
return fmt.Sprintf("`%s`", resp.UserID)
}
}
@ -119,137 +87,65 @@ func fnResolveIdentifier(ce *Event) {
if api == nil {
return
}
allLogins := ce.User.GetUserLogins()
createChat := ce.Command == "start-chat" || ce.Command == "pm"
identifier := strings.Join(identifierParts, " ")
resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat)
for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ {
resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat)
}
resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat)
if err != nil {
ce.Log.Err(err).Msg("Failed to resolve identifier")
ce.Reply("Failed to resolve identifier: %v", err)
return
} else if resp == nil {
ce.ReplyAdvanced(fmt.Sprintf("Identifier <code>%s</code> not found", html.EscapeString(identifier)), false, true)
return
}
formattedName := formatResolveIdentifierResult(resp)
formattedName := formatResolveIdentifierResult(ce.Ctx, resp)
if createChat {
name := resp.Portal.Name
if name == "" {
name = resp.Portal.MXID.String()
if resp.Chat == nil {
ce.Reply("Interface error: network connector did not return chat for create chat request")
return
}
if !resp.JustCreated {
ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
portal := resp.Chat.Portal
if portal == nil {
portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey)
if err != nil {
ce.Log.Err(err).Msg("Failed to get portal")
ce.Reply("Failed to get portal: %v", err)
return
}
}
if resp.Chat.PortalInfo == nil {
resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal)
if err != nil {
ce.Log.Err(err).Msg("Failed to get portal info")
ce.Reply("Failed to get portal info: %v", err)
return
}
}
if portal.MXID != "" {
name := portal.Name
if name == "" {
name = portal.MXID.String()
}
portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{})
ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
} else {
ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo)
if err != nil {
ce.Log.Err(err).Msg("Failed to create room")
ce.Reply("Failed to create room: %v", err)
return
}
name := portal.Name
if name == "" {
name = portal.MXID.String()
}
ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
}
} else {
ce.Reply("Found %s", formattedName)
}
}
var CommandCreateGroup = &FullHandler{
Func: fnCreateGroup,
Name: "create-group",
Aliases: []string{"create"},
Help: HelpMeta{
Section: HelpSectionChats,
Description: "Create a new group chat for the current Matrix room",
Args: "[_group type_]",
},
RequiresLogin: true,
NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI],
}
func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) {
evt, err := provider.GetStateEvent(ctx, roomID, evtType, "")
if err != nil {
zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation")
} else if evt != nil {
content, _ = evt.Content.Parsed.(T)
}
return
}
func fnCreateGroup(ce *Event) {
ce.Bridge.Matrix.GetCapabilities()
login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group")
if api == nil {
return
}
stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState)
if !ok {
ce.Reply("Matrix connector doesn't support fetching room state")
return
}
members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID)
if err != nil {
ce.Log.Err(err).Msg("Failed to get room members for group creation")
ce.Reply("Failed to get room members: %v", err)
return
}
caps := ce.Bridge.Network.GetCapabilities()
params := &bridgev2.GroupCreateParams{
Username: "",
Participants: make([]networkid.UserID, 0, len(members)-2),
Parent: nil, // TODO check space parent event
Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider),
Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider),
Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider),
Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider),
RoomID: ce.RoomID,
}
for userID, member := range members {
if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() {
continue
}
if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok {
params.Participants = append(params.Participants, parsedUserID)
} else if !ce.Bridge.Config.SplitPortals {
if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil {
ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member")
} else if user != nil {
// TODO add user logins to participants
//for _, login := range user.GetUserLogins() {
// params.Participants = append(params.Participants, login.GetUserID())
//}
}
}
}
if len(caps.Provisioning.GroupCreation) == 0 {
ce.Reply("No group creation types defined in network capabilities")
return
} else if len(remainingArgs) > 0 {
params.Type = remainingArgs[0]
} else if len(caps.Provisioning.GroupCreation) == 1 {
for params.Type = range caps.Provisioning.GroupCreation {
// The loop assigns the variable we want
}
} else {
types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `")
ce.Reply("Please specify type of group to create: `%s`", types)
return
}
resp, err := provisionutil.CreateGroup(ce.Ctx, login, params)
if err != nil {
ce.Reply("Failed to create group: %v", err)
return
}
var postfix string
if len(resp.FailedParticipants) > 0 {
failedParticipantsStrings := make([]string, len(resp.FailedParticipants))
i := 0
for participantID, meta := range resp.FailedParticipants {
failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason)
i++
}
postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n")
}
ce.Reply("Successfully created group `%s`%s", resp.ID, postfix)
}
var CommandSearch = &FullHandler{
Func: fnSearch,
Name: "search",
@ -267,67 +163,35 @@ func fnSearch(ce *Event) {
ce.Reply("Usage: `$cmdprefix search <query>`")
return
}
login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
_, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
if api == nil {
return
}
resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " "))
results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " "))
if err != nil {
ce.Log.Err(err).Msg("Failed to search for users")
ce.Reply("Failed to search for users: %v", err)
return
}
resultsString := make([]string, len(resp.Results))
for i, res := range resp.Results {
formattedName := formatResolveIdentifierResult(res)
resultsString := make([]string, len(results))
for i, res := range results {
formattedName := formatResolveIdentifierResult(ce.Ctx, res)
resultsString[i] = fmt.Sprintf("* %s", formattedName)
if res.Portal != nil && res.Portal.MXID != "" {
portalName := res.Portal.Name
if portalName == "" {
portalName = res.Portal.MXID.String()
if res.Chat != nil {
if res.Chat.Portal == nil {
res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey)
if err != nil {
ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal")
}
}
if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" {
portalName := res.Chat.Portal.Name
if portalName == "" {
portalName = res.Chat.Portal.MXID.String()
}
resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL())
}
resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL())
}
}
ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n"))
}
var CommandMute = &FullHandler{
Func: fnMute,
Name: "mute",
Aliases: []string{"unmute"},
Help: HelpMeta{
Section: HelpSectionChats,
Description: "Mute or unmute a chat on the remote network",
Args: "[duration]",
},
RequiresPortal: true,
RequiresLogin: true,
NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI],
}
func fnMute(ce *Event) {
_, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats")
var mutedUntil int64
if ce.Command == "mute" {
mutedUntil = -1
if len(ce.Args) > 0 {
duration, err := time.ParseDuration(ce.Args[0])
if err != nil {
ce.Reply("Invalid duration: %v", err)
return
}
mutedUntil = time.Now().Add(duration).UnixMilli()
}
}
err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{
MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{
Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil},
Portal: ce.Portal,
},
})
if err != nil {
ce.Reply("Failed to %s chat: %v", ce.Command, err)
} else {
ce.React("✅️")
}
}

View file

@ -7,7 +7,13 @@
package database
import (
"encoding/json"
"reflect"
"strings"
"go.mau.fi/util/dbutil"
"golang.org/x/exp/constraints"
"golang.org/x/exp/maps"
"maunium.net/go/mautrix/bridgev2/networkid"
@ -28,7 +34,6 @@ type Database struct {
UserPortal *UserPortalQuery
BackfillTask *BackfillTaskQuery
KV *KVQuery
PublicMedia *PublicMediaQuery
}
type MetaMerger interface {
@ -136,12 +141,6 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa
BridgeID: bridgeID,
Database: db,
},
PublicMedia: &PublicMediaQuery{
BridgeID: bridgeID,
QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia {
return &PublicMedia{}
}),
},
}
}
@ -152,3 +151,55 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID)
panic("bridge ID mismatch")
}
}
func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) {
if val, found := m[key]; found {
floatVal, ok := val.(float64)
if ok {
return T(floatVal), true
}
tVal, ok := val.(T)
if ok {
return tVal, true
}
}
return 0, false
}
func unmarshalMerge(input []byte, data any, extra *map[string]any) error {
err := json.Unmarshal(input, data)
if err != nil {
return err
}
err = json.Unmarshal(input, extra)
if err != nil {
return err
}
if *extra == nil {
*extra = make(map[string]any)
}
return nil
}
func marshalMerge(data any, extra map[string]any) ([]byte, error) {
if extra == nil {
return json.Marshal(data)
}
merged := make(map[string]any)
maps.Copy(merged, extra)
dataRef := reflect.ValueOf(data).Elem()
dataType := dataRef.Type()
for _, field := range reflect.VisibleFields(dataType) {
parts := strings.Split(field.Tag.Get("json"), ",")
if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" {
continue
}
fieldVal := dataRef.FieldByIndex(field.Index)
if fieldVal.IsZero() {
delete(merged, parts[0])
} else {
merged[parts[0]] = fieldVal.Interface()
}
}
return json.Marshal(merged)
}

View file

@ -12,92 +12,54 @@ import (
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// Deprecated: use [event.DisappearingType]
type DisappearingType = event.DisappearingType
// DisappearingType represents the type of a disappearing message timer.
type DisappearingType string
// Deprecated: use constants in event package
const (
DisappearingTypeNone = event.DisappearingTypeNone
DisappearingTypeAfterRead = event.DisappearingTypeAfterRead
DisappearingTypeAfterSend = event.DisappearingTypeAfterSend
DisappearingTypeNone DisappearingType = ""
DisappearingTypeAfterRead DisappearingType = "after_read"
DisappearingTypeAfterSend DisappearingType = "after_send"
)
// DisappearingSetting represents a disappearing message timer setting
// by combining a type with a timer and an optional start timestamp.
type DisappearingSetting struct {
Type event.DisappearingType
Type DisappearingType
Timer time.Duration
DisappearAt time.Time
}
func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting {
if evt == nil || evt.Type == event.DisappearingTypeNone {
return DisappearingSetting{}
}
return DisappearingSetting{
Type: evt.Type,
Timer: evt.Timer.Duration,
}
}
func (ds DisappearingSetting) Normalize() DisappearingSetting {
if ds.Type == event.DisappearingTypeNone {
ds.Timer = 0
} else if ds.Timer == 0 {
ds.Type = event.DisappearingTypeNone
}
return ds
}
func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting {
ds.DisappearAt = start.Add(ds.Timer)
return ds
}
func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer {
if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 {
return &event.BeeperDisappearingTimer{}
}
return &event.BeeperDisappearingTimer{
Type: ds.Type,
Timer: jsontime.MS(ds.Timer),
}
}
type DisappearingMessageQuery struct {
BridgeID networkid.BridgeID
*dbutil.QueryHelper[*DisappearingMessage]
}
type DisappearingMessage struct {
BridgeID networkid.BridgeID
RoomID id.RoomID
EventID id.EventID
Timestamp time.Time
BridgeID networkid.BridgeID
RoomID id.RoomID
EventID id.EventID
DisappearingSetting
}
const (
upsertDisappearingMessageQuery = `
INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at
`
startDisappearingMessagesQuery = `
UPDATE disappearing_message
SET disappear_at=$1 + timer
WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4
RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read'
RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at
`
getUpcomingDisappearingMessagesQuery = `
SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
SELECT bridge_id, mx_room, mxid, type, timer, disappear_at
FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2
ORDER BY disappear_at LIMIT $3
`
@ -111,8 +73,8 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe
return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...)
}
func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano())
func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID)
}
func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) {
@ -124,19 +86,17 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even
}
func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
var timestamp int64
var disappearAt sql.NullInt64
err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &timestamp, &d.Type, &d.Timer, &disappearAt)
err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt)
if err != nil {
return nil, err
}
if disappearAt.Valid {
d.DisappearAt = time.Unix(0, disappearAt.Int64)
}
d.Timestamp = time.Unix(0, timestamp)
return d, nil
}
func (d *DisappearingMessage) sqlVariables() []any {
return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
}

View file

@ -7,17 +7,12 @@
package database
import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
"fmt"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/id"
)
@ -27,55 +22,6 @@ type GhostQuery struct {
*dbutil.QueryHelper[*Ghost]
}
type ExtraProfile map[string]json.RawMessage
func (ep *ExtraProfile) Set(key string, value any) error {
if key == "displayname" || key == "avatar_url" {
return fmt.Errorf("cannot set reserved profile key %q", key)
}
marshaled, err := json.Marshal(value)
if err != nil {
return err
}
if *ep == nil {
*ep = make(ExtraProfile)
}
(*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled)
return nil
}
func (ep *ExtraProfile) With(key string, value any) *ExtraProfile {
exerrors.PanicIfNotNil(ep.Set(key, value))
return ep
}
func canonicalizeIfObject(data json.RawMessage) json.RawMessage {
if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
return canonicaljson.CanonicalJSONAssumeValid(data)
}
return data
}
func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) {
if len(*ep) == 0 {
return
}
if *dest == nil {
*dest = make(ExtraProfile)
}
for key, val := range *ep {
if key == "displayname" || key == "avatar_url" {
continue
}
existing, exists := (*dest)[key]
if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) {
(*dest)[key] = val
changed = true
}
}
return
}
type Ghost struct {
BridgeID networkid.BridgeID
ID networkid.UserID
@ -89,14 +35,13 @@ type Ghost struct {
ContactInfoSet bool
IsBot bool
Identifiers []string
ExtraProfile ExtraProfile
Metadata any
}
const (
getGhostBaseQuery = `
SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
FROM ghost
`
getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2`
@ -104,14 +49,13 @@ const (
insertGhostQuery = `
INSERT INTO ghost (
bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
`
updateGhostQuery = `
UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6,
name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10,
identifiers=$11, extra_profile=$12, metadata=$13
name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12
WHERE bridge_id=$1 AND id=$2
`
)
@ -142,7 +86,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) {
&g.BridgeID, &g.ID,
&g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC,
&g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot,
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
)
if err != nil {
return nil, err
@ -172,6 +116,6 @@ func (g *Ghost) sqlVariables() []any {
g.BridgeID, g.ID,
g.Name, g.AvatarID, avatarHash, g.AvatarMXC,
g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot,
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
}
}

View file

@ -23,7 +23,6 @@ const (
KeySplitPortalsEnabled Key = "split_portals_enabled"
KeyBridgeInfoVersion Key = "bridge_info_version"
KeyEncryptionStateResynced Key = "encryption_state_resynced"
KeyRecoveryKey Key = "recovery_key"
)
type KVQuery struct {

View file

@ -11,12 +11,9 @@ import (
"crypto/sha256"
"database/sql"
"encoding/base64"
"fmt"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
@ -27,7 +24,6 @@ type MessageQuery struct {
BridgeID networkid.BridgeID
MetaType MetaTypeCreator
*dbutil.QueryHelper[*Message]
chunkDeleteLock sync.Mutex
}
type Message struct {
@ -68,8 +64,8 @@ const (
getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`
getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5`
getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1`
getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1`
getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1`
getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1`
getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1`
getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4`
getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1`
@ -100,10 +96,6 @@ const (
deleteMessagePartByRowIDQuery = `
DELETE FROM message WHERE bridge_id=$1 AND rowid=$2
`
deleteMessageChunkQuery = `
DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5
`
getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1`
)
func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) {
@ -188,85 +180,6 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error {
return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID)
}
func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) {
res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) {
err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID)
return
}
const deleteChunkSize = 100_000
func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error {
if mq.GetDB().Dialect != dbutil.SQLite {
return nil
}
log := zerolog.Ctx(ctx).With().
Str("action", "delete messages in chunks").
Stringer("portal_key", portal).
Logger()
if !mq.chunkDeleteLock.TryLock() {
log.Warn().Msg("Portal deletion lock is being held, waiting...")
mq.chunkDeleteLock.Lock()
log.Debug().Msg("Acquired portal deletion lock after waiting")
}
defer mq.chunkDeleteLock.Unlock()
total, err := mq.CountMessagesInPortal(ctx, portal)
if err != nil {
return fmt.Errorf("failed to count messages in portal: %w", err)
} else if total < deleteChunkSize/3 {
return nil
}
globalMaxRowID, err := mq.getMaxRowID(ctx)
if err != nil {
return fmt.Errorf("failed to get max row ID: %w", err)
}
log.Debug().
Int("total_count", total).
Int64("global_max_row_id", globalMaxRowID).
Msg("Portal has lots of messages, deleting in chunks to avoid database locks")
maxRowID := int64(deleteChunkSize)
globalMaxRowID += deleteChunkSize * 1.2
var dbTimeUsed time.Duration
globalStart := time.Now()
for total > 500 && maxRowID < globalMaxRowID {
start := time.Now()
count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID)
duration := time.Since(start)
dbTimeUsed += duration
if err != nil {
return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err)
}
total -= int(count)
maxRowID += deleteChunkSize
sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond))
log.Debug().
Int64("max_row_id", maxRowID).
Int64("deleted_count", count).
Int("remaining_count", total).
Dur("duration", duration).
Dur("sleep_time", sleepTime).
Msg("Deleted chunk of messages")
select {
case <-time.After(sleepTime):
case <-ctx.Done():
return ctx.Err()
}
}
log.Debug().
Int("remaining_count", total).
Dur("db_time_used", dbTimeUsed).
Dur("total_duration", time.Since(globalStart)).
Msg("Finished chunked delete of messages in portal")
return nil
}
func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) {
err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count)
return

View file

@ -16,7 +16,6 @@ import (
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -35,20 +34,9 @@ type PortalQuery struct {
*dbutil.QueryHelper[*Portal]
}
type CapStateFlags uint32
func (csf CapStateFlags) Has(flag CapStateFlags) bool {
return csf&flag != 0
}
const (
CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota
)
type CapabilityState struct {
Source networkid.UserLoginID `json:"source"`
ID string `json:"id"`
Flags CapStateFlags `json:"flags"`
}
type Portal struct {
@ -56,31 +44,30 @@ type Portal struct {
networkid.PortalKey
MXID id.RoomID
ParentKey networkid.PortalKey
RelayLoginID networkid.UserLoginID
OtherUserID networkid.UserID
Name string
Topic string
AvatarID networkid.AvatarID
AvatarHash [32]byte
AvatarMXC id.ContentURIString
NameSet bool
TopicSet bool
AvatarSet bool
NameIsCustom bool
InSpace bool
MessageRequest bool
RoomType RoomType
Disappear DisappearingSetting
CapState CapabilityState
Metadata any
ParentKey networkid.PortalKey
RelayLoginID networkid.UserLoginID
OtherUserID networkid.UserID
Name string
Topic string
AvatarID networkid.AvatarID
AvatarHash [32]byte
AvatarMXC id.ContentURIString
NameSet bool
TopicSet bool
AvatarSet bool
NameIsCustom bool
InSpace bool
RoomType RoomType
Disappear DisappearingSetting
CapState CapabilityState
Metadata any
}
const (
getPortalBaseQuery = `
SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
name_set, topic_set, avatar_set, name_is_custom, in_space, message_request,
name_set, topic_set, avatar_set, name_is_custom, in_space,
room_type, disappear_type, disappear_timer, cap_state,
metadata
FROM portal
@ -89,9 +76,7 @@ const (
getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')`
getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL`
getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC`
getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2`
getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3`
getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1`
getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3`
@ -102,11 +87,11 @@ const (
bridge_id, id, receiver, mxid,
parent_id, parent_receiver, relay_login_id, other_user_id,
name, topic, avatar_id, avatar_hash, avatar_mxc,
name_set, avatar_set, topic_set, name_is_custom, in_space, message_request,
name_set, avatar_set, topic_set, name_is_custom, in_space,
room_type, disappear_type, disappear_timer, cap_state,
metadata, relay_bridge_id
) VALUES (
$1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24,
$1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23,
CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END
)
`
@ -115,8 +100,8 @@ const (
SET mxid=$4, parent_id=$5, parent_receiver=$6,
relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END,
other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13,
name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19,
room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24
name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18,
room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23
WHERE bridge_id=$1 AND id=$2 AND receiver=$3
`
deletePortalQuery = `
@ -126,33 +111,15 @@ const (
reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3`
migrateToSplitPortalsQuery = `
UPDATE portal
SET receiver=new_receiver
FROM (
SELECT bridge_id, id, COALESCE((
SELECT login_id
FROM user_portal
WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
LIMIT 1
), (
SELECT login_id
FROM user_portal
WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id
LIMIT 1
), (
SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
), '') AS new_receiver
FROM portal
WHERE receiver='' AND bridge_id=$1
) updates
WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS (
SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver
)
`
fixParentsAfterSplitPortalMigrationQuery = `
UPDATE portal
SET parent_receiver=receiver
WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''
AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver);
SET receiver=COALESCE((
SELECT login_id
FROM user_portal
WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
LIMIT 1
), (
SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
), '')
WHERE receiver='' AND bridge_id=$1
`
)
@ -180,10 +147,6 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID)
}
func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID)
}
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID)
}
@ -192,10 +155,6 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid.
return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID)
}
func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) {
return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID)
}
func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) {
return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver)
}
@ -226,14 +185,6 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error)
return res.RowsAffected()
}
func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) {
res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString
var disappearTimer sql.NullInt64
@ -242,7 +193,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
&p.BridgeID, &p.ID, &p.Receiver, &mxid,
&parentID, &parentReceiver, &relayLoginID, &otherUserID,
&p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC,
&p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest,
&p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace,
&p.RoomType, &disappearType, &disappearTimer,
dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata},
)
@ -257,7 +208,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
}
if disappearType.Valid {
p.Disappear = DisappearingSetting{
Type: event.DisappearingType(disappearType.String),
Type: DisappearingType(disappearType.String),
Timer: time.Duration(disappearTimer.Int64),
}
}
@ -289,7 +240,7 @@ func (p *Portal) sqlVariables() []any {
p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID),
dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID),
p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC,
p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest,
p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace,
p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer),
dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata},
}

View file

@ -1,72 +0,0 @@
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"time"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/id"
)
type PublicMediaQuery struct {
BridgeID networkid.BridgeID
*dbutil.QueryHelper[*PublicMedia]
}
type PublicMedia struct {
BridgeID networkid.BridgeID
PublicID string
MXC id.ContentURI
Keys *attachment.EncryptedFile
MimeType string
Expiry time.Time
}
const (
upsertPublicMediaQuery = `
INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry
`
getPublicMediaQuery = `
SELECT bridge_id, public_id, mxc, keys, mimetype, expiry
FROM public_media WHERE bridge_id=$1 AND public_id=$2
`
)
func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error {
ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID)
return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...)
}
func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) {
return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID)
}
func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) {
var expiry sql.NullInt64
var mimetype sql.NullString
err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry)
if err != nil {
return nil, err
}
if expiry.Valid {
pm.Expiry = time.Unix(0, expiry.Int64)
}
pm.MimeType = mimetype.String
return pm, nil
}
func (pm *PublicMedia) sqlVariables() []any {
return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)}
}

View file

@ -1,4 +1,4 @@
-- v0 -> v27 (compatible with v9+): Latest revision
-- v0 -> v22 (compatible with v9+): Latest revision
CREATE TABLE "user" (
bridge_id TEXT NOT NULL,
mxid TEXT NOT NULL,
@ -48,7 +48,6 @@ CREATE TABLE portal (
topic_set BOOLEAN NOT NULL,
name_is_custom BOOLEAN NOT NULL DEFAULT false,
in_space BOOLEAN NOT NULL,
message_request BOOLEAN NOT NULL DEFAULT false,
room_type TEXT NOT NULL,
disappear_type TEXT,
disappear_timer BIGINT,
@ -65,7 +64,6 @@ CREATE TABLE portal (
ON DELETE SET NULL ON UPDATE CASCADE
);
CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid);
CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
CREATE TABLE ghost (
bridge_id TEXT NOT NULL,
@ -80,7 +78,6 @@ CREATE TABLE ghost (
contact_info_set BOOLEAN NOT NULL,
is_bot BOOLEAN NOT NULL,
identifiers jsonb NOT NULL,
extra_profile jsonb,
metadata jsonb NOT NULL,
PRIMARY KEY (bridge_id, id)
@ -130,7 +127,6 @@ CREATE TABLE disappearing_message (
bridge_id TEXT NOT NULL,
mx_room TEXT NOT NULL,
mxid TEXT NOT NULL,
timestamp BIGINT NOT NULL DEFAULT 0,
type TEXT NOT NULL,
timer BIGINT NOT NULL,
disappear_at BIGINT,
@ -141,7 +137,6 @@ CREATE TABLE disappearing_message (
REFERENCES portal (bridge_id, mxid)
ON DELETE CASCADE
);
CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
CREATE TABLE reaction (
bridge_id TEXT NOT NULL,
@ -220,14 +215,3 @@ CREATE TABLE kv_store (
PRIMARY KEY (bridge_id, key)
);
CREATE TABLE public_media (
bridge_id TEXT NOT NULL,
public_id TEXT NOT NULL,
mxc TEXT NOT NULL,
keys jsonb,
mimetype TEXT,
expiry BIGINT,
PRIMARY KEY (bridge_id, public_id)
);

View file

@ -1,2 +0,0 @@
-- v23 (compatible with v9+): Add event timestamp for disappearing messages
ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0;

View file

@ -1,11 +0,0 @@
-- v24 (compatible with v9+): Custom URLs for public media
CREATE TABLE public_media (
bridge_id TEXT NOT NULL,
public_id TEXT NOT NULL,
mxc TEXT NOT NULL,
keys jsonb,
mimetype TEXT,
expiry BIGINT,
PRIMARY KEY (bridge_id, public_id)
);

View file

@ -1,2 +0,0 @@
-- v25 (compatible with v9+): Flag for message request portals
ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false;

View file

@ -1,3 +0,0 @@
-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents
CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);

View file

@ -1,2 +0,0 @@
-- v27 (compatible with v9+): Add column for extra ghost profile metadata
ALTER TABLE ghost ADD COLUMN extra_profile jsonb;

View file

@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin {
func (u *UserLogin) sqlVariables() []any {
var remoteProfile dbutil.JSON
if !u.RemoteProfile.IsZero() {
if !u.RemoteProfile.IsEmpty() {
remoteProfile.Data = &u.RemoteProfile
}
return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}}

View file

@ -86,8 +86,8 @@ func (dl *DisappearLoop) Stop() {
}
}
func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) {
startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS)
func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) {
startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages")
return

View file

@ -38,53 +38,35 @@ var ErrNotLoggedIn = errors.New("not logged in")
// but direct media is not enabled.
var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled")
var ErrPortalIsDeleted = errors.New("portal is deleted")
var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event")
// Common message status errors
var (
ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false)
ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false)
ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true)
ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage()
ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage()
ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage()
ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage()
ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage()
ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage()
ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage()
ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage()
ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage()
ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
)
// Common login interface errors

View file

@ -9,15 +9,12 @@ package bridgev2
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"maps"
"net/http"
"slices"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exmime"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
@ -137,11 +134,10 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32
}
type UserInfo struct {
Identifiers []string
Name *string
Avatar *Avatar
IsBot *bool
ExtraProfile database.ExtraProfile
Identifiers []string
Name *string
Avatar *Avatar
IsBot *bool
ExtraUpdates ExtraUpdater[*Ghost]
}
@ -189,9 +185,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
return true
}
func (ghost *Ghost) getExtraProfileMeta() any {
func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra {
bridgeName := ghost.Bridge.Network.GetName()
baseExtra := &event.BeeperProfileExtra{
return &event.BeeperProfileExtra{
RemoteID: string(ghost.ID),
Identifiers: ghost.Identifiers,
Service: bridgeName.BeeperBridgeType,
@ -199,35 +195,23 @@ func (ghost *Ghost) getExtraProfileMeta() any {
IsBridgeBot: false,
IsNetworkBot: ghost.IsBot,
}
if len(ghost.ExtraProfile) == 0 {
return baseExtra
}
mergedExtra := maps.Clone(ghost.ExtraProfile)
baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra))
exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra))
return mergedExtra
}
func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool {
if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta {
ghost.ContactInfoSet = false
return false
}
func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool {
if identifiers != nil {
slices.Sort(identifiers)
}
changed := extraProfile.CopyTo(&ghost.ExtraProfile)
if ghost.ContactInfoSet &&
(identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) &&
(isBot == nil || *isBot == ghost.IsBot) {
return false
}
if identifiers != nil {
changed = changed || !slices.Equal(identifiers, ghost.Identifiers)
ghost.Identifiers = identifiers
}
if isBot != nil {
changed = changed || *isBot != ghost.IsBot
ghost.IsBot = *isBot
}
if ghost.ContactInfoSet && !changed {
return false
}
err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta())
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata")
@ -250,7 +234,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool {
}
func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) {
if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
return
}
info, err := source.Client.GetUserInfo(ctx, ghost)
@ -260,16 +244,12 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin
zerolog.Ctx(ctx).Debug().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
Bool("has_avatar", ghost.AvatarMXC != "").
Bool("avatar_set", ghost.AvatarSet).
Msg("Updating ghost info in IfNecessary call")
ghost.UpdateInfo(ctx, info)
} else {
zerolog.Ctx(ctx).Trace().
Bool("has_name", ghost.Name != "").
Bool("name_set", ghost.NameSet).
Bool("has_avatar", ghost.AvatarMXC != "").
Bool("avatar_set", ghost.AvatarSet).
Msg("No ghost info received in IfNecessary call")
}
}
@ -297,14 +277,9 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) {
}
if info.Avatar != nil {
update = ghost.UpdateAvatar(ctx, info.Avatar) || update
} else if oldAvatar == "" && !ghost.AvatarSet {
// Special case: nil avatar means we're not expecting one ever, if we don't currently have
// one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary.
ghost.AvatarSet = true
update = true
}
if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil {
update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update
if info.Identifiers != nil || info.IsBot != nil {
update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update
}
if info.ExtraUpdates != nil {
update = info.ExtraUpdates(ctx, ghost) || update

View file

@ -13,7 +13,6 @@ import (
"strings"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
)
// LoginProcess represents a single occurrence of a user logging into the remote network.
@ -179,8 +178,6 @@ const (
LoginInputFieldTypeToken LoginInputFieldType = "token"
LoginInputFieldTypeURL LoginInputFieldType = "url"
LoginInputFieldTypeDomain LoginInputFieldType = "domain"
LoginInputFieldTypeSelect LoginInputFieldType = "select"
LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code"
)
type LoginInputDataField struct {
@ -192,13 +189,8 @@ type LoginInputDataField struct {
Name string `json:"name"`
// The description of the field shown to the user.
Description string `json:"description"`
// A default value that the client can pre-fill the field with.
DefaultValue string `json:"default_value,omitempty"`
// A regex pattern that the client can use to validate input client-side.
Pattern string `json:"pattern,omitempty"`
// For fields of type select, the valid options.
// Pattern may also be filled with a regex that matches the same options.
Options []string `json:"options,omitempty"`
// A function that validates the input and optionally cleans it up before it's submitted to the connector.
Validate func(string) (string, error) `json:"-"`
}
@ -273,23 +265,6 @@ func (f *LoginInputDataField) FillDefaultValidate() {
type LoginUserInputParams struct {
// The fields that the user needs to fill in.
Fields []LoginInputDataField `json:"fields"`
// Attachments to display alongside the input fields.
Attachments []*LoginUserInputAttachment `json:"attachments"`
}
type LoginUserInputAttachment struct {
Type event.MessageType `json:"type,omitempty"`
FileName string `json:"filename,omitempty"`
Content []byte `json:"content,omitempty"`
Info LoginUserInputAttachmentInfo `json:"info,omitempty"`
}
type LoginUserInputAttachmentInfo struct {
MimeType string `json:"mimetype,omitempty"`
Width int `json:"w,omitempty"`
Height int `json:"h,omitempty"`
Size int `json:"size,omitempty"`
}
type LoginCompleteParams struct {

View file

@ -26,7 +26,6 @@ import (
_ "go.mau.fi/util/dbutil/litestream"
"go.mau.fi/util/exbytes"
"go.mau.fi/util/exsync"
"go.mau.fi/util/ptr"
"go.mau.fi/util/random"
"golang.org/x/sync/semaphore"
@ -81,8 +80,6 @@ type Connector struct {
MediaConfig mautrix.RespMediaConfig
SpecVersions *mautrix.RespVersions
SpecCaps *mautrix.RespCapabilities
specCapsLock sync.Mutex
Capabilities *bridgev2.MatrixCapabilities
IgnoreUnsupportedServer bool
@ -144,20 +141,14 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) {
br.EventProcessor.On(event.EventReaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent)
br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent)
br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent)
br.EventProcessor.On(event.StateMember, br.handleRoomEvent)
br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent)
br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent)
br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent)
br.EventProcessor.On(event.StateTopic, br.handleRoomEvent)
br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent)
br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent)
br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent)
br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent)
br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent)
br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent)
br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent)
br.Bot = br.AS.BotIntent()
br.Crypto = NewCryptoHelper(br)
br.Bridge.Commands.(*commands.Processor).AddHandlers(
@ -282,7 +273,7 @@ func (br *Connector) GetPublicAddress() string {
if br.Config.AppService.PublicAddress == "https://bridge.example.com" {
return ""
}
return strings.TrimRight(br.Config.AppService.PublicAddress, "/")
return br.Config.AppService.PublicAddress
}
func (br *Connector) GetRouter() *http.ServeMux {
@ -344,18 +335,16 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) {
}
func (br *Connector) ensureConnection(ctx context.Context) {
triedToRegister := false
for {
versions, err := br.Bot.Versions(ctx)
if err != nil {
if errors.Is(err, mautrix.MForbidden) && !triedToRegister {
if errors.Is(err, mautrix.MForbidden) {
br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying")
err = br.Bot.EnsureRegistered(ctx)
if err != nil {
br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN")
os.Exit(16)
}
triedToRegister = true
} else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) {
br.logInitialRequestError(err, "/versions request failed with auth error")
os.Exit(16)
@ -368,9 +357,6 @@ func (br *Connector) ensureConnection(ctx context.Context) {
*br.AS.SpecVersions = *versions
br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites)
br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending)
br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange)
br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) ||
(br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo)
break
}
}
@ -415,21 +401,6 @@ func (br *Connector) ensureConnection(ctx context.Context) {
br.Bot.EnsureAppserviceConnection(ctx)
}
func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities {
br.specCapsLock.Lock()
defer br.specCapsLock.Unlock()
if br.SpecCaps != nil {
return br.SpecCaps
}
caps, err := br.Bot.Capabilities(ctx)
if err != nil {
br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver")
return nil
}
br.SpecCaps = caps
return caps
}
func (br *Connector) fetchMediaConfig(ctx context.Context) {
cfg, err := br.Bot.GetMediaConfig(ctx)
if err != nil {
@ -538,8 +509,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
Msg("Failed to send MSS event")
}
}
if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice &&
(ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
content := ms.ToNoticeEvent(evt)
if editEvent != "" {
content.SetEdit(editEvent)
@ -623,28 +593,13 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve
}
func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) {
if stateKey == "" {
switch eventType {
case event.StateCreate:
createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID)
if err != nil || createEvt != nil {
return createEvt, err
}
case event.StateJoinRules:
joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID)
if err != nil {
return nil, err
} else if joinRulesContent != nil {
return &event.Event{
Type: event.StateJoinRules,
RoomID: roomID,
StateKey: ptr.Ptr(""),
Content: event.Content{Parsed: joinRulesContent},
}, nil
}
if eventType == event.StateCreate && stateKey == "" {
createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID)
if err != nil || createEvt != nil {
return createEvt, err
}
}
return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey)
return br.Bot.FullStateEvent(ctx, roomID, eventType, "")
}
func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
@ -687,7 +642,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr
if intent != nil {
intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp)
}
if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction {
if evt.Type != event.EventEncrypted {
err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content)
if err != nil {
return nil, err

View file

@ -24,7 +24,6 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event"
@ -38,9 +37,9 @@ func init() {
var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil)
var NoSessionFound = crypto.ErrNoSessionFound
var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex
var UnknownMessageIndex = olm.ErrUnknownMessageIndex
var NoSessionFound = crypto.NoSessionFound
var DuplicateMessageIndex = crypto.DuplicateMessageIndex
var UnknownMessageIndex = olm.UnknownMessageIndex
type CryptoHelper struct {
bridge *Connector
@ -136,19 +135,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
return err
}
if isExistingDevice {
if !helper.verifyKeysAreOnServer(ctx) {
return nil
}
} else {
err = helper.ShareKeys(ctx)
if err != nil {
return fmt.Errorf("failed to share device keys: %w", err)
}
}
if helper.bridge.Config.Encryption.SelfSign {
if !helper.doSelfSign(ctx) {
os.Exit(34)
}
helper.verifyKeysAreOnServer(ctx)
}
go helper.resyncEncryptionInfo(context.TODO())
@ -156,46 +143,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
return nil
}
func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool {
log := zerolog.Ctx(ctx)
hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx)
if err != nil {
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status")
return false
}
log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status")
keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey)
if !hasKeys || keyInDB == "overwrite" {
if keyInDB != "" && keyInDB != "overwrite" {
log.WithLevel(zerolog.FatalLevel).
Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.")
return false
}
recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx)
if recoveryKey != "" {
helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey)
}
if err != nil {
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign")
return false
}
log.Info().Msg("Generated new recovery key and self-signed bot device")
} else if !isVerified {
if keyInDB == "" {
log.WithLevel(zerolog.FatalLevel).
Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.")
return false
}
err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB)
if err != nil {
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key")
return false
}
log.Info().Msg("Verified bot device with existing recovery key")
}
return true
}
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
@ -210,12 +157,12 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
var evt event.EncryptionEventContent
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
if err != nil {
log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event")
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
_, err = helper.store.DB.Exec(ctx, `
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
`, roomID)
if err != nil {
log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync")
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync")
}
} else {
maxAge := evt.RotationPeriodMillis
@ -238,9 +185,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
`, maxAge, maxMessages, roomID)
if err != nil {
log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table")
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table")
} else {
log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table")
log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table")
}
}
}
@ -286,7 +233,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
if err != nil {
return nil, false, fmt.Errorf("failed to find existing device ID: %w", err)
} else if len(deviceID) > 0 {
helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database")
helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
}
// Create a new client instance with the default AS settings (including as_token),
// the Login call will then override the access token in the client.
@ -327,7 +274,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
return client, deviceID != "", nil
}
func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool {
func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) {
helper.log.Debug().Msg("Making sure keys are still on server")
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
@ -340,11 +287,10 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool {
}
device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID]
if ok && len(device.Keys) > 0 {
return true
return
}
helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto")
helper.Reset(ctx, false)
return false
}
func (helper *CryptoHelper) Start() {
@ -439,7 +385,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy
var encrypted *event.EncryptedEventContent
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) {
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
return
}
helper.log.Debug().Err(err).

View file

@ -9,7 +9,6 @@ package matrix
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@ -28,7 +27,6 @@ import (
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
@ -45,13 +43,13 @@ type ASIntent struct {
var _ bridgev2.MatrixAPI = (*ASIntent)(nil)
var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil)
var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil)
func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) {
if extra == nil {
extra = &bridgev2.MatrixSendExtra{}
}
if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) {
// TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions
if eventType == event.EventRedaction {
parsedContent := content.Parsed.(*event.RedactionEventContent)
as.Matrix.AddDoublePuppetValue(content)
return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{
@ -59,7 +57,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
Extra: content.Raw,
})
}
if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction {
if eventType != event.EventReaction && eventType != event.EventRedaction {
msgContent, ok := content.Parsed.(*event.MessageEventContent)
if ok {
msgContent.AddPerMessageProfileFallback()
@ -84,27 +82,16 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
eventType = event.EventEncrypted
}
}
return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()})
}
func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) {
if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
if extra.Timestamp.IsZero() {
return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content)
} else {
return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli())
}
if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil {
return nil, fmt.Errorf("failed to check if room is encrypted: %w", err)
} else if encrypted && as.Connector.Crypto != nil {
if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil {
return nil, err
}
eventType = event.EventEncrypted
}
return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID})
}
func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) {
targetContent, ok := content.Parsed.(*event.MemberEventContent)
if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" {
targetContent := content.Parsed.(*event.MemberEventContent)
if targetContent.Displayname != "" || targetContent.AvatarURL != "" {
return
}
memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID)
@ -139,7 +126,11 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e
if eventType == event.StateMember {
as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content)
}
resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()})
if ts.IsZero() {
resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content)
} else {
resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli())
}
if err != nil && eventType == event.StateMember {
var httpErr mautrix.HTTPError
if errors.As(err, &httpErr) && httpErr.RespError != nil &&
@ -421,7 +412,6 @@ func (as *ASIntent) UploadMediaStream(
removeAndClose(replFile)
removeAndClose(tempFile)
}
req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
startedAsyncUpload = true
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
@ -454,7 +444,6 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn
as.Connector.uploadSema.Release(int64(len(req.ContentBytes)))
}
}
req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
var resp *mautrix.RespCreateMXC
resp, err = as.Matrix.UploadAsync(ctx, req)
if resp != nil {
@ -486,62 +475,11 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr
return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL)
}
func dataToFields(data any) (map[string]json.RawMessage, error) {
fields, ok := data.(map[string]json.RawMessage)
if ok {
return fields, nil
}
d, err := json.Marshal(data)
if err != nil {
return nil, err
}
d = canonicaljson.CanonicalJSONAssumeValid(d)
err = json.Unmarshal(d, &fields)
return fields, err
}
func marshalField(val any) json.RawMessage {
data, _ := json.Marshal(val)
if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
return canonicaljson.CanonicalJSONAssumeValid(data)
}
return data
}
var nullJSON = json.RawMessage("null")
func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error {
if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
return as.Matrix.BeeperUpdateProfile(ctx, data)
} else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo {
fields, err := dataToFields(data)
if err != nil {
return fmt.Errorf("failed to marshal fields: %w", err)
}
currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID)
if err != nil {
return fmt.Errorf("failed to get current profile: %w", err)
}
for key, val := range fields {
existing, ok := currentProfile.Extra[key]
if !ok {
if bytes.Equal(val, nullJSON) {
continue
}
err = as.Matrix.SetProfileField(ctx, key, val)
} else if !bytes.Equal(marshalField(existing), val) {
if bytes.Equal(val, nullJSON) {
err = as.Matrix.DeleteProfileField(ctx, key)
} else {
err = as.Matrix.SetProfileField(ctx, key, val)
}
}
if err != nil {
return fmt.Errorf("failed to set profile field %q: %w", key, err)
}
}
if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
return nil
}
return nil
return as.Matrix.BeeperUpdateProfile(ctx, data)
}
func (as *ASIntent) GetMXID() id.UserID {
@ -583,39 +521,6 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent {
return content
}
func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) {
if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
// Hungryserv doesn't override the capabilities endpoint nor do room versions
return
}
caps := as.Connector.fetchCapabilities(ctx)
roomVer := req.RoomVersion
if roomVer == "" && caps != nil && caps.RoomVersions != nil {
roomVer = id.RoomVersion(caps.RoomVersions.Default)
}
if roomVer != "" && !roomVer.PrivilegedRoomCreators() {
return
}
creators, _ := req.CreationContent["additional_creators"].([]id.UserID)
creators = append(slices.Clone(creators), as.GetMXID())
if req.PowerLevelOverride != nil {
for _, creator := range creators {
delete(req.PowerLevelOverride.Users, creator)
}
}
for _, evt := range req.InitialState {
if evt.Type != event.StatePowerLevels {
continue
}
content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent)
if ok {
for _, creator := range creators {
delete(content.Users, creator)
}
}
}
}
func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) {
if as.Connector.Config.Encryption.Default {
req.InitialState = append(req.InitialState, &event.Event{
@ -631,7 +536,6 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom)
}
req.CreationContent["m.federate"] = false
}
as.filterCreateRequestForV12(ctx, req)
resp, err := as.Matrix.CreateRoom(ctx, req)
if err != nil {
return "", err
@ -673,9 +577,6 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id.
}
func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error {
if roomID == "" {
return nil
}
if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) {
err := as.Matrix.BeeperDeleteRoom(ctx, roomID)
if err != nil {
@ -773,23 +674,3 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T
})
}
}
func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) {
evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID)
if err != nil {
return nil, err
}
err = evt.Content.ParseRaw(evt.Type)
if err != nil {
zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content")
}
if evt.Type == event.EventEncrypted {
if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt {
return nil, errors.New("can't decrypt the event")
}
return as.Connector.Crypto.Decrypt(ctx, evt)
}
return evt, nil
}

View file

@ -27,11 +27,6 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) {
if br.shouldIgnoreEvent(evt) {
return
}
if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember {
zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events")
br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
return
}
if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require {
zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required")
br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true)
@ -68,10 +63,6 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event)
case event.EphemeralEventTyping:
typingContent := evt.Content.AsTyping()
typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser)
case event.BeeperEphemeralEventAIStream:
if br.shouldIgnoreEvent(evt) {
return
}
}
br.Bridge.QueueMatrixEvent(ctx, evt)
}
@ -85,11 +76,6 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents {
log.Debug().Msg("Dropping event from user with no permission to send events")
br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
return
}
ctx = log.WithContext(ctx)
if br.Crypto == nil {
br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true)
@ -101,18 +87,17 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
decryptionStart := time.Now()
decrypted, err := br.Crypto.Decrypt(ctx, evt)
decryptionRetryCount := 0
var errorEventID id.EventID
if errors.Is(err, NoSessionFound) {
decryptionRetryCount = 1
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false)
go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false)
if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = br.Crypto.Decrypt(ctx, evt)
} else {
go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID)
go br.waitLongerForSession(ctx, evt, decryptionStart)
return
}
}
@ -121,18 +106,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true)
return
}
br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart))
br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart))
}
func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) {
func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) {
log := zerolog.Ctx(ctx)
content := evt.Content.AsEncrypted()
log.Debug().
Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, requesting keys and waiting longer...")
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
var errorEventID *id.EventID
go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false)
if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@ -235,6 +220,7 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event
go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount)
decrypted.Mautrix.CheckpointSent = true
decrypted.Mautrix.DecryptionDuration = duration
decrypted.Mautrix.EventSource |= event.SourceDecrypted
br.EventProcessor.Dispatch(ctx, decrypted)
if errorEventID != nil && *errorEventID != "" {
_, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID)

View file

@ -66,12 +66,7 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s
} else if errors.Is(err, dbutil.ErrForeignTables) {
br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info")
} else if errors.Is(err, dbutil.ErrNotOwned) {
var noe dbutil.NotOwnedError
if errors.As(err, &noe) && noe.Owner == br.Name {
br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?")
} else {
br.Log.Info().Msg("Sharing the same database with different programs is not supported")
}
br.Log.Info().Msg("Sharing the same database with different programs is not supported")
} else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) {
br.Log.Info().Msg("Downgrading the bridge is not supported")
}

View file

@ -1,161 +0,0 @@
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package mxmain
import (
"fmt"
"iter"
"os"
"reflect"
"strconv"
"strings"
"go.mau.fi/util/random"
)
var randomParseFilePrefix = random.String(16) + "READFILE:"
func parseEnv(prefix string) iter.Seq2[[]string, string] {
return func(yield func([]string, string) bool) {
for _, s := range os.Environ() {
if !strings.HasPrefix(s, prefix) {
continue
}
kv := strings.SplitN(s, "=", 2)
key := strings.TrimPrefix(kv[0], prefix)
value := kv[1]
if strings.HasSuffix(key, "_FILE") {
key = strings.TrimSuffix(key, "_FILE")
value = randomParseFilePrefix + value
}
key = strings.ToLower(key)
if !strings.ContainsRune(key, '.') {
key = strings.ReplaceAll(key, "__", ".")
}
if !yield(strings.Split(key, "."), value) {
return
}
}
}
}
func reflectYAMLFieldName(f *reflect.StructField) string {
parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2)
fieldName := parts[0]
if fieldName == "-" && len(parts) == 1 {
return ""
}
if fieldName == "" {
return strings.ToLower(f.Name)
}
return fieldName
}
type reflectGetResult struct {
val reflect.Value
valKind reflect.Kind
remainingPath []string
}
func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) {
if len(path) == 0 {
return &reflectGetResult{val: rv, valKind: rv.Kind()}, true
}
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
switch rv.Kind() {
case reflect.Map:
return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true
case reflect.Struct:
fields := reflect.VisibleFields(rv.Type())
for _, field := range fields {
fieldName := reflectYAMLFieldName(&field)
if fieldName != "" && fieldName == path[0] {
return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:])
}
}
default:
}
return nil, false
}
func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) {
if len(path) > 0 && path[0] == "network" {
return reflectGetYAML(network, path[1:])
}
return reflectGetYAML(main, path)
}
func formatKeyString(key []string) string {
return strings.Join(key, "->")
}
func UpdateConfigFromEnv(cfg, networkData any, prefix string) error {
cfgVal := reflect.ValueOf(cfg)
networkVal := reflect.ValueOf(networkData)
for key, value := range parseEnv(prefix) {
field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key)
if !ok {
return fmt.Errorf("%s not found", formatKeyString(key))
}
if strings.HasPrefix(value, randomParseFilePrefix) {
filepath := strings.TrimPrefix(value, randomParseFilePrefix)
fileData, err := os.ReadFile(filepath)
if err != nil {
return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err)
}
value = strings.TrimSpace(string(fileData))
}
var parsedVal any
var err error
switch field.valKind {
case reflect.String:
parsedVal = value
case reflect.Bool:
parsedVal, err = strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
parsedVal, err = strconv.ParseInt(value, 10, 64)
if err != nil {
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
parsedVal, err = strconv.ParseUint(value, 10, 64)
if err != nil {
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
}
case reflect.Float32, reflect.Float64:
parsedVal, err = strconv.ParseFloat(value, 64)
if err != nil {
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
}
default:
return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key))
}
if field.val.Kind() == reflect.Ptr {
if field.val.IsNil() {
field.val.Set(reflect.New(field.val.Type().Elem()))
}
field.val = field.val.Elem()
}
if field.val.Kind() == reflect.Map {
key = key[:len(key)-len(field.remainingPath)]
mapKeyStr := strings.Join(field.remainingPath, ".")
key = append(key, mapKeyStr)
if field.val.Type().Key().Kind() != reflect.String {
return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key))
}
field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal))
} else {
field.val.Set(reflect.ValueOf(parsedVal))
}
}
return nil
}

View file

@ -15,7 +15,6 @@ bridge:
# By default, users who are in the same group on the remote network will be
# in the same Matrix room bridged to that group. If this is set to true,
# every user will get their own Matrix room instead.
# SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST.
split_portals: false
# Should the bridge resend `m.bridge` events to all portals on startup?
resend_bridge_info: false
@ -29,9 +28,6 @@ bridge:
# How long after an unknown error should the bridge attempt a full reconnect?
# Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value.
unknown_error_auto_reconnect: null
# Maximum number of times to do the auto-reconnect above.
# The counter is per login, but is never reset except on logout and restart.
unknown_error_max_auto_reconnects: 10
# Should leaving Matrix rooms be bridged as leaving groups on the remote network?
bridge_matrix_leave: false
@ -50,11 +46,6 @@ bridge:
# Should cross-room reply metadata be bridged?
# Most Matrix clients don't support this and servers may reject such messages too.
cross_room_replies: false
# If a state event fails to bridge, should the bridge revert any state changes made by that event?
revert_failed_state_changes: false
# In portals with no relay set, should Matrix users be kicked if they're
# not logged into an account that's in the remote chat?
kick_matrix_users: true
# What should be done to portal rooms when a user logs out or is logged out?
# Permitted values:
@ -244,9 +235,6 @@ matrix:
# The threshold as bytes after which the bridge should roundtrip uploads via the disk
# rather than keeping the whole file in memory.
upload_file_threshold: 5242880
# Should the bridge set additional custom profile info for ghosts?
# This can make a lot of requests, as there's no batch profile update endpoint.
ghost_extra_profile_info: false
# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors.
analytics:
@ -259,8 +247,10 @@ analytics:
# Settings for provisioning API
provisioning:
# Prefix for the provisioning API paths.
prefix: /_matrix/provision
# Shared secret for authentication. If set to "generate" or null, a random secret will be generated,
# or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters.
# or if set to "disable", the provisioning API will be disabled.
shared_secret: generate
# Whether to allow provisioning API requests to be authed using Matrix access tokens.
# This follows the same rules as double puppeting to determine which server to contact to check the token,
@ -286,14 +276,6 @@ public_media:
expiry: 0
# Length of hash to use for public media URLs. Must be between 0 and 32.
hash_length: 32
# The path prefix for generated URLs. Note that this will NOT change the path where media is actually served.
# If you change this, you must configure your reverse proxy to rewrite the path accordingly.
path_prefix: /_mautrix/publicmedia
# Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs?
# If false, the generated URLs will just have the MXC URI and a HMAC signature.
# The hash_length field will be used to decide the length of the generated URL.
# This also allows invalidating URLs by deleting the database entry.
use_database: false
# Settings for converting remote media to custom mxc:// URIs instead of reuploading.
# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html
@ -384,12 +366,6 @@ encryption:
# Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861).
# Changing this option requires updating the appservice registration file.
msc4190: false
# Whether to encrypt reactions and reply metadata as per MSC4392.
msc4392: false
# Should the bridge bot generate a recovery key and cross-signing keys and verify itself?
# Note that without the latest version of MSC4190, this will fail if you reset the bridge database.
# The generated recovery key will be saved in the kv_store table under `recovery_key`.
self_sign: false
# Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled.
# You must use a client that supports requesting keys from other users to use this feature.
allow_key_sharing: true
@ -452,16 +428,6 @@ encryption:
# You should not enable this option unless you understand all the implications.
disable_device_change_key_rotation: false
# Prefix for environment variables. All variables with this prefix must map to valid config fields.
# Nesting in variable names is represented with a dot (.).
# If there are no dots in the name, two underscores (__) are replaced with a dot.
#
# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token.
# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily.
#
# If this is null, reading config fields from environment will be disabled.
env_config_prefix: null
# Logging config. See https://github.com/tulir/zeroconfig for details.
logging:
min_level: debug

View file

@ -135,10 +135,7 @@ func (br *BridgeMain) CheckLegacyDB(
}
var dbVersion int
err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion)
if err != nil {
log.Fatal().Err(err).Msg("Failed to get database version")
return
} else if dbVersion < expectedVersion {
if dbVersion < expectedVersion {
log.Fatal().
Int("expected_version", expectedVersion).
Int("version", dbVersion).

View file

@ -26,7 +26,6 @@ import (
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exzerolog"
"go.mau.fi/util/progver"
"gopkg.in/yaml.v3"
flag "maunium.net/go/mauflag"
@ -63,9 +62,6 @@ type BridgeMain struct {
// git tag to see if the built version is the release or a dev build.
// You can either bump this right after a release or right before, as long as it matches on the release commit.
Version string
// SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning,
// such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch.
SemCalVer bool
// PostInit is a function that will be called after the bridge has been initialized but before it is started.
PostInit func()
@ -90,7 +86,11 @@ type BridgeMain struct {
RegistrationPath string
SaveConfig bool
ver progver.ProgramVersion
baseVersion string
commit string
LinkifiedVersion string
VersionDesc string
BuildTime time.Time
AdditionalShortFlags string
AdditionalLongFlags string
@ -99,7 +99,14 @@ type BridgeMain struct {
}
type VersionJSONOutput struct {
progver.ProgramVersion
Name string
URL string
Version string
IsRelease bool
Commit string
FormattedVersion string
BuildTime time.Time
OS string
Arch string
@ -140,11 +147,18 @@ func (br *BridgeMain) PreInit() {
flag.PrintHelp()
os.Exit(0)
} else if *version {
fmt.Println(br.ver.VersionDescription)
fmt.Println(br.VersionDesc)
os.Exit(0)
} else if *versionJSON {
output := VersionJSONOutput{
ProgramVersion: br.ver,
URL: br.URL,
Name: br.Name,
Version: br.baseVersion,
IsRelease: br.Version == br.baseVersion,
Commit: br.commit,
FormattedVersion: br.Version,
BuildTime: br.BuildTime,
OS: runtime.GOOS,
Arch: runtime.GOARCH,
@ -226,8 +240,8 @@ func (br *BridgeMain) Init() {
br.Log.Info().
Str("name", br.Name).
Str("version", br.ver.FormattedVersion).
Time("built_at", br.ver.BuildTime).
Str("version", br.Version).
Time("built_at", br.BuildTime).
Str("go_version", runtime.Version()).
Msg("Initializing bridge")
@ -241,7 +255,7 @@ func (br *BridgeMain) Init() {
br.Matrix.AS.DoublePuppetValue = br.Name
br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{
Func: func(ce *commands.Event) {
ce.Reply(br.ver.MarkdownDescription())
ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123))
},
Name: "version",
Help: commands.HelpMeta{
@ -354,13 +368,6 @@ func (br *BridgeMain) LoadConfig() {
}
}
cfg.Bridge.Backfill = cfg.Backfill
if cfg.EnvConfigPrefix != "" {
err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err)
os.Exit(10)
}
}
br.Config = &cfg
}
@ -439,12 +446,42 @@ func (br *BridgeMain) Stop() {
//
// (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`)
func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) {
br.ver = progver.ProgramVersion{
Name: br.Name,
URL: br.URL,
BaseVersion: br.Version,
SemCalVer: br.SemCalVer,
}.Init(tag, commit, rawBuildTime)
mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent)
br.Version = br.ver.FormattedVersion
br.baseVersion = br.Version
if len(tag) > 0 && tag[0] == 'v' {
tag = tag[1:]
}
if tag != br.Version {
suffix := ""
if !strings.HasSuffix(br.Version, "+dev") {
suffix = "+dev"
}
if len(commit) > 8 {
br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8])
} else {
br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix)
}
}
br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version)
if tag == br.Version {
br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag)
} else if len(commit) > 8 {
br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1)
}
var buildTime time.Time
if rawBuildTime != "unknown" {
buildTime, _ = time.Parse(time.RFC3339, rawBuildTime)
}
var builtWith string
if buildTime.IsZero() {
rawBuildTime = "unknown"
builtWith = runtime.Version()
} else {
rawBuildTime = buildTime.Format(time.RFC1123)
builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version())
}
mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent)
br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith)
br.commit = commit
br.BuildTime = buildTime
}

View file

@ -30,7 +30,6 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/bridgev2/provisionutil"
"maunium.net/go/mautrix/bridgev2/status"
"maunium.net/go/mautrix/federation"
"maunium.net/go/mautrix/id"
@ -85,9 +84,10 @@ const (
provisioningUserKey provisioningContextKey = iota
provisioningUserLoginKey
provisioningLoginProcessKey
ProvisioningKeyRequest
)
const ProvisioningKeyRequest = "fi.mau.provision.request"
func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
}
@ -96,7 +96,12 @@ func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
return prov.Router
}
func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI {
type IProvisioningAPI interface {
GetRouter() *http.ServeMux
GetUser(r *http.Request) *bridgev2.User
}
func (br *Connector) GetProvisioning() IProvisioningAPI {
return br.Provisioning
}
@ -114,7 +119,6 @@ func (prov *ProvisioningAPI) Init() {
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
prov.Router = http.NewServeMux()
prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami)
prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities)
prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows)
prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep)
@ -124,7 +128,7 @@ func (prov *ProvisioningAPI) Init() {
prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup)
prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)
if prov.br.Config.Provisioning.EnableSessionTransfers {
prov.log.Debug().Msg("Enabling session transfer API")
@ -206,20 +210,12 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI
}
}
func disabledAuth(w http.ResponseWriter, r *http.Request) {
mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w)
}
func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
secret := prov.br.Config.Provisioning.SharedSecret
if len(secret) < 16 {
return http.HandlerFunc(disabledAuth)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth == "" {
mautrix.MMissingToken.WithMessage("Missing auth token").Write(w)
} else if !exstrings.ConstantTimeEqual(auth, secret) {
} else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w)
} else {
h.ServeHTTP(w, r)
@ -228,10 +224,6 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
}
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
secret := prov.br.Config.Provisioning.SharedSecret
if len(secret) < 16 {
return http.HandlerFunc(disabledAuth)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth == "" && prov.GetAuthFromRequest != nil {
@ -245,7 +237,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
if userID == "" && prov.GetUserIDFromRequest != nil {
userID = prov.GetUserIDFromRequest(r)
}
if !exstrings.ConstantTimeEqual(auth, secret) {
if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
var err error
if strings.HasPrefix(auth, "openid:") {
err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:"))
@ -324,7 +316,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) {
prevState.UserID = ""
prevState.RemoteID = ""
prevState.RemoteName = ""
prevState.RemoteProfile = status.RemoteProfile{}
prevState.RemoteProfile = nil
resp.Logins[i] = RespWhoamiLogin{
StateEvent: prevState.StateEvent,
StateTS: prevState.Timestamp,
@ -356,24 +348,18 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques
})
}
func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) {
exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning)
}
var ErrNilStep = errors.New("bridge returned nil step with no error")
var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"}
func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) {
overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r)
if failed {
return
}
user := prov.GetUser(r)
if overrideLogin == nil && user.HasTooManyLogins() {
ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w)
return
}
login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID"))
login, err := prov.net.CreateLogin(
r.Context(),
prov.GetUser(r),
r.PathValue("flowID"),
)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
RespondWithError(w, err, "Internal error creating login process")
@ -403,18 +389,10 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
Override: overrideLogin,
}
prov.loginsLock.Unlock()
zerolog.Ctx(r.Context()).Info().
Any("first_step", firstStep).
Msg("Created login process")
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep})
}
func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) {
zerolog.Ctx(ctx).Info().
Str("step_id", step.StepID).
Str("user_login_id", string(step.CompleteParams.UserLoginID)).
Msg("Login completed successfully")
prov.deleteLogin(login, false)
if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID {
return
}
@ -428,15 +406,6 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
}, bridgev2.DeleteOpts{LogoutRemote: true})
}
func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) {
if cancel {
login.Process.Cancel()
}
prov.loginsLock.Lock()
delete(prov.logins, login.ID)
prov.loginsLock.Unlock()
}
func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) {
loginID := r.PathValue("loginProcessID")
prov.loginsLock.RLock()
@ -507,14 +476,11 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input")
RespondWithError(w, err, "Internal error submitting input")
prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
} else {
zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
@ -528,14 +494,11 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait")
RespondWithError(w, err, "Internal error waiting for login")
prov.deleteLogin(login, true)
return
}
login.NextStep = nextStep
if nextStep.Type == bridgev2.LoginStepTypeComplete {
prov.handleCompleteStep(r.Context(), login, nextStep)
} else {
zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
}
@ -619,23 +582,115 @@ func RespondWithError(w http.ResponseWriter, err error, message string) {
}
}
type RespResolveIdentifier struct {
ID networkid.UserID `json:"id"`
Name string `json:"name,omitempty"`
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
Identifiers []string `json:"identifiers,omitempty"`
MXID id.UserID `json:"mxid,omitempty"`
DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
}
func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) {
login := prov.GetLoginForRequest(w, r)
if login == nil {
return
}
resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat)
api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
if !ok {
mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w)
return
}
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
RespondWithError(w, err, "Internal error resolving identifier")
return
} else if resp == nil {
mautrix.MNotFound.WithMessage("Identifier not found").Write(w)
} else {
status := http.StatusOK
if resp.JustCreated {
status = http.StatusCreated
}
exhttp.WriteJSONResponse(w, status, resp)
return
}
apiResp := &RespResolveIdentifier{
ID: resp.UserID,
}
status := http.StatusOK
if resp.Ghost != nil {
if resp.UserInfo != nil {
resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo)
}
apiResp.Name = resp.Ghost.Name
apiResp.AvatarURL = resp.Ghost.AvatarMXC
apiResp.Identifiers = resp.Ghost.Identifiers
apiResp.MXID = resp.Ghost.Intent.GetMXID()
} else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
apiResp.Name = *resp.UserInfo.Name
}
if resp.Chat != nil {
if resp.Chat.Portal == nil {
resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal")
mautrix.MUnknown.WithMessage("Failed to get portal").Write(w)
return
}
}
if createChat && resp.Chat.Portal.MXID == "" {
status = http.StatusCreated
err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room")
mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w)
return
}
}
apiResp.DMRoomID = resp.Chat.Portal.MXID
}
exhttp.WriteJSONResponse(w, status, apiResp)
}
type RespGetContactList struct {
Contacts []*RespResolveIdentifier `json:"contacts"`
}
func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) {
apiResp = make([]*RespResolveIdentifier, len(resp))
for i, contact := range resp {
apiContact := &RespResolveIdentifier{
ID: contact.UserID,
}
apiResp[i] = apiContact
if contact.UserInfo != nil {
if contact.UserInfo.Name != nil {
apiContact.Name = *contact.UserInfo.Name
}
if contact.UserInfo.Identifiers != nil {
apiContact.Identifiers = contact.UserInfo.Identifiers
}
}
if contact.Ghost != nil {
if contact.Ghost.Name != "" {
apiContact.Name = contact.Ghost.Name
}
if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
apiContact.Identifiers = contact.Ghost.Identifiers
}
apiContact.AvatarURL = contact.Ghost.AvatarMXC
apiContact.MXID = contact.Ghost.Intent.GetMXID()
}
if contact.Chat != nil {
if contact.Chat.Portal == nil {
var err error
contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
}
}
if contact.Chat.Portal != nil {
apiContact.DMRoomID = contact.Chat.Portal.MXID
}
}
}
return
}
func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) {
@ -643,18 +698,30 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque
if login == nil {
return
}
resp, err := provisionutil.GetContactList(r.Context(), login)
if err != nil {
RespondWithError(w, err, "Internal error getting contact list")
api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
if !ok {
mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w)
return
}
exhttp.WriteJSONResponse(w, http.StatusOK, resp)
resp, err := api.GetContactList(r.Context())
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
RespondWithError(w, err, "Internal error fetching contact list")
return
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{
Contacts: prov.processResolveIdentifiers(r.Context(), resp),
})
}
type ReqSearchUsers struct {
Query string `json:"query"`
}
type RespSearchUsers struct {
Results []*RespResolveIdentifier `json:"results"`
}
func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) {
var req ReqSearchUsers
err := json.NewDecoder(r.Body).Decode(&req)
@ -667,12 +734,20 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ
if login == nil {
return
}
resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query)
if err != nil {
RespondWithError(w, err, "Internal error searching users")
api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
if !ok {
mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w)
return
}
exhttp.WriteJSONResponse(w, http.StatusOK, resp)
resp, err := api.SearchUsers(r.Context(), req.Query)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
RespondWithError(w, err, "Internal error fetching contact list")
return
}
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{
Results: prov.processResolveIdentifiers(r.Context(), resp),
})
}
func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) {
@ -684,24 +759,11 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request
}
func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) {
var req bridgev2.GroupCreateParams
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
return
}
req.Type = r.PathValue("type")
login := prov.GetLoginForRequest(w, r)
if login == nil {
return
}
resp, err := provisionutil.CreateGroup(r.Context(), login, &req)
if err != nil {
RespondWithError(w, err, "Internal error creating group")
return
}
exhttp.WriteJSONResponse(w, http.StatusOK, resp)
mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w)
}
type ReqExportCredentials struct {

View file

@ -361,25 +361,14 @@ paths:
$ref: '#/components/responses/InternalError'
501:
$ref: '#/components/responses/NotSupported'
/v3/create_group/{type}:
/v3/create_group:
post:
tags: [ snc ]
summary: Create a group chat on the remote network.
operationId: createGroup
parameters:
- $ref: "#/components/parameters/loginID"
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/GroupCreateParams'
responses:
200:
description: Identifier resolved successfully
content:
application/json:
schema:
$ref: '#/components/schemas/CreatedGroup'
401:
$ref: '#/components/responses/Unauthorized'
404:
@ -400,7 +389,7 @@ components:
- username
- meow@example.com
loginID:
name: login_id
name: loginID
in: query
description: An optional explicit login ID to do the action through.
required: false
@ -583,74 +572,6 @@ components:
description: The Matrix room ID of the direct chat with the user.
examples:
- '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
GroupCreateParams:
type: object
description: |
Parameters for creating a group chat.
The /capabilities endpoint response must be checked to see which fields are actually allowed.
properties:
type:
type: string
description: The type of group to create.
examples:
- channel
username:
type: string
description: The public username for the created group.
participants:
type: array
description: The users to add to the group initially.
items:
type: string
parent:
type: object
name:
type: object
description: The `m.room.name` event content for the room.
properties:
name:
type: string
avatar:
type: object
description: The `m.room.avatar` event content for the room.
properties:
url:
type: string
format: mxc
topic:
type: object
description: The `m.room.topic` event content for the room.
properties:
topic:
type: string
disappear:
type: object
description: The `com.beeper.disappearing_timer` event content for the room.
properties:
type:
type: string
timer:
type: number
room_id:
type: string
format: matrix_room_id
description: |
An existing Matrix room ID to bridge to.
The other parameters must be already in sync with the room state when using this parameter.
CreatedGroup:
type: object
description: A successfully created group chat.
required: [id, mxid]
properties:
id:
type: string
description: The internal chat ID of the created group.
mxid:
type: string
format: matrix_room_id
description: The Matrix room ID of the portal.
examples:
- '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
LoginStep:
type: object
description: A step in a login process.
@ -714,7 +635,7 @@ components:
type:
type: string
description: The type of field.
enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ]
enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ]
id:
type: string
description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge.
@ -728,53 +649,10 @@ components:
description: A more detailed description of the field shown to the user.
examples:
- Include the country code with a +
default_value:
type: string
description: A default value that the client can pre-fill the field with.
pattern:
type: string
format: regex
description: A regular expression that the field value must match.
options:
type: array
description: For fields of type select, the valid options.
items:
type: string
attachments:
type: array
description: A list of media attachments to show the user alongside the form fields.
items:
type: object
description: A media attachment to show the user.
required: [ type, filename, content ]
properties:
type:
type: string
description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported.
enum: [ m.image, m.audio ]
filename:
type: string
description: The filename for the media attachment.
content:
type: string
description: The raw file content for the attachment encoded in base64.
info:
type: object
description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted.
properties:
mimetype:
type: string
description: The MIME type for the media content.
examples: [ image/png, audio/mpeg ]
w:
type: number
description: The width of the media in pixels. Only applicable for images and videos.
h:
type: number
description: The height of the media in pixels. Only applicable for images and videos.
size:
type: number
description: The size of the media content in number of bytes. Strongly recommended to include.
- description: Cookie login step
required: [ type, cookies ]
properties:

View file

@ -7,26 +7,16 @@
package matrix
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"mime"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -43,10 +33,7 @@ func (br *Connector) initPublicMedia() error {
return fmt.Errorf("public media hash length is negative")
}
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia)
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia)
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia)
return nil
}
@ -57,20 +44,6 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte {
return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)]
}
func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte {
hasher := hmac.New(sha256.New, br.pubMediaSigKey)
hasher.Write([]byte(pm.MXC.String()))
hasher.Write([]byte(pm.MimeType))
if pm.Keys != nil {
hasher.Write([]byte(pm.Keys.Version))
hasher.Write([]byte(pm.Keys.Key.Algorithm))
hasher.Write([]byte(pm.Keys.Key.Key))
hasher.Write([]byte(pm.Keys.InitVector))
hasher.Write([]byte(pm.Keys.Hashes.SHA256))
}
return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength]
}
func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte {
var expiresAt []byte
if br.Config.PublicMedia.Expiry > 0 {
@ -120,47 +93,9 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
http.Error(w, "checksum expired", http.StatusGone)
return
}
br.doProxyMedia(w, r, contentURI, nil, "")
}
func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) {
if !br.Config.PublicMedia.UseDatabase {
http.Error(w, "public media short links are disabled", http.StatusNotFound)
return
}
log := zerolog.Ctx(r.Context())
media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID"))
if err != nil {
log.Err(err).Msg("Failed to get public media from database")
http.Error(w, "failed to get media metadata", http.StatusInternalServerError)
return
} else if media == nil {
http.Error(w, "media ID not found", http.StatusNotFound)
return
} else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) {
// This is not gone as it can still be refreshed in the DB
http.Error(w, "media expired", http.StatusNotFound)
return
} else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil {
http.Error(w, "media keys are malformed", http.StatusInternalServerError)
return
}
br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType)
}
var safeMimes = []string{
"text/css", "text/plain", "text/csv",
"application/json", "application/ld+json",
"image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif",
"video/mp4", "video/webm", "video/ogg", "video/quicktime",
"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave",
"audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac",
}
func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) {
resp, err := br.Bot.Download(r.Context(), contentURI)
if err != nil {
zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
http.Error(w, "failed to download media", http.StatusInternalServerError)
return
}
@ -168,41 +103,11 @@ func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, conten
for _, hdr := range proxyHeadersToCopy {
w.Header()[hdr] = resp.Header[hdr]
}
stream := resp.Body
if encInfo != nil {
if mimeType == "" {
mimeType = "application/octet-stream"
}
contentDisposition := "attachment"
if slices.Contains(safeMimes, mimeType) {
contentDisposition = "inline"
}
dispositionArgs := map[string]string{}
if filename := r.PathValue("filename"); filename != "" {
dispositionArgs["filename"] = filename
}
w.Header().Set("Content-Type", mimeType)
w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs))
// Note: this won't check the Close result like it should, but it's probably not a big deal here
stream = encInfo.DecryptStream(stream)
} else if filename := r.PathValue("filename"); filename != "" {
contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
if contentDisposition == "" {
contentDisposition = "attachment"
}
w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{
"filename": filename,
}))
}
w.WriteHeader(http.StatusOK)
_, _ = io.Copy(w, stream)
_, _ = io.Copy(w, resp.Body)
}
func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string {
return br.getPublicMediaAddressWithFileName(contentURI, "")
}
func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string {
if br.pubMediaSigKey == nil {
return ""
}
@ -210,69 +115,11 @@ func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIS
if err != nil || !parsed.IsValid() {
return ""
}
fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_"))
if fileName == ".." {
fileName = ""
}
parts := []string{
return fmt.Sprintf(
"%s/_mautrix/publicmedia/%s/%s/%s",
br.GetPublicAddress(),
strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
parsed.Homeserver,
parsed.FileID,
base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)),
fileName,
}
if fileName == "" {
parts = parts[:len(parts)-1]
}
return strings.Join(parts, "/")
}
func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) {
if br.pubMediaSigKey == nil {
return "", bridgev2.ErrPublicMediaDisabled
}
if !br.Config.PublicMedia.UseDatabase {
if evt.File != nil {
return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled)
}
return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil
}
mxc := evt.URL
var keys *attachment.EncryptedFile
if evt.File != nil {
mxc = evt.File.URL
keys = &evt.File.EncryptedFile
}
parsedMXC, err := mxc.Parse()
if err != nil {
return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
}
pm := &database.PublicMedia{
MXC: parsedMXC,
Keys: keys,
MimeType: evt.GetInfo().MimeType,
}
if br.Config.PublicMedia.Expiry > 0 {
pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second)
}
pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm))
err = br.Bridge.DB.PublicMedia.Put(ctx, pm)
if err != nil {
return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
}
fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_"))
if fileName == ".." {
fileName = ""
}
parts := []string{
br.GetPublicAddress(),
strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
pm.PublicID,
fileName,
}
if fileName == "" {
parts = parts[:len(parts)-1]
}
return strings.Join(parts, "/"), nil
)
}

View file

@ -14,8 +14,6 @@ import (
"os"
"time"
"go.mau.fi/util/exhttp"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
@ -25,10 +23,8 @@ import (
)
type MatrixCapabilities struct {
AutoJoinInvites bool
BatchSending bool
ArbitraryMemberChange bool
ExtraProfileMeta bool
AutoJoinInvites bool
BatchSending bool
}
type MatrixConnector interface {
@ -62,54 +58,35 @@ type MatrixConnector interface {
}
type MatrixConnectorWithArbitraryRoomState interface {
MatrixConnector
GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error)
}
type MatrixConnectorWithServer interface {
MatrixConnector
GetPublicAddress() string
GetRouter() *http.ServeMux
}
type IProvisioningAPI interface {
GetRouter() *http.ServeMux
GetUser(r *http.Request) *User
}
type MatrixConnectorWithProvisioning interface {
MatrixConnector
GetProvisioning() IProvisioningAPI
}
type MatrixConnectorWithPublicMedia interface {
MatrixConnector
GetPublicMediaAddress(contentURI id.ContentURIString) string
GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error)
}
type MatrixConnectorWithNameDisambiguation interface {
MatrixConnector
IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error)
}
type MatrixConnectorWithBridgeIdentifier interface {
MatrixConnector
GetUniqueBridgeID() string
}
type MatrixConnectorWithURLPreviews interface {
MatrixConnector
GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error)
}
type MatrixConnectorWithPostRoomBridgeHandling interface {
MatrixConnector
HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error
}
type MatrixConnectorWithAnalytics interface {
MatrixConnector
TrackAnalytics(userID id.UserID, event string, properties map[string]any)
}
@ -124,15 +101,9 @@ type DirectNotificationData struct {
}
type MatrixConnectorWithNotifications interface {
MatrixConnector
DisplayNotification(ctx context.Context, data *DirectNotificationData)
}
type MatrixConnectorWithHTTPSettings interface {
MatrixConnector
GetHTTPClientSettings() exhttp.ClientSettings
}
type MatrixSendExtra struct {
Timestamp time.Time
MessageMeta *database.Message
@ -205,21 +176,12 @@ type MatrixAPI interface {
TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error
MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error
GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error)
}
type StreamOrderReadingMatrixAPI interface {
MatrixAPI
MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error
}
type MarkAsDMMatrixAPI interface {
MatrixAPI
MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error
}
type EphemeralSendingMatrixAPI interface {
MatrixAPI
BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error)
}

View file

@ -88,36 +88,6 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI,
rejectInvite(ctx, evt, intent, "")
}
func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) {
if portal.MXID == "" {
return
}
log := zerolog.Ctx(ctx)
existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID)
if err != nil {
log.Err(err).
Stringer("old_portal_mxid", portal.MXID).
Msg("Failed to check existing portal members, deleting room")
} else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok {
log.Debug().
Stringer("old_portal_mxid", portal.MXID).
Msg("Inviter has no member event in old portal, deleting room")
} else if targetUserMember.Membership.IsInviteOrJoin() {
return
} else {
log.Debug().
Stringer("old_portal_mxid", portal.MXID).
Str("membership", string(targetUserMember.Membership)).
Msg("Inviter is not in old portal, deleting room")
}
if err = portal.RemoveMXID(ctx); err != nil {
log.Err(err).Msg("Failed to delete old portal mxid")
} else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
log.Err(err).Msg("Failed to clean up old portal room")
}
}
func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult {
ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey()))
validator, ok := br.Network.(IdentifierValidatingNetwork)
@ -195,7 +165,34 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
return EventHandlingResultFailed
}
}
portal.CleanupOrphanedDM(ctx, sender.MXID)
if portal.MXID != "" {
doCleanup := true
existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID)
if err != nil {
log.Err(err).
Stringer("old_portal_mxid", portal.MXID).
Msg("Failed to check existing portal members, deleting room")
} else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok {
log.Debug().
Stringer("old_portal_mxid", portal.MXID).
Msg("Inviter has no member event in old portal, deleting room")
} else if targetUserMember.Membership.IsInviteOrJoin() {
doCleanup = false
} else {
log.Debug().
Stringer("old_portal_mxid", portal.MXID).
Str("membership", string(targetUserMember.Membership)).
Msg("Inviter is not in old portal, deleting room")
}
if doCleanup {
if err = portal.RemoveMXID(ctx); err != nil {
log.Err(err).Msg("Failed to delete old portal mxid")
} else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
log.Err(err).Msg("Failed to clean up old portal room")
}
}
}
err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID())
if err != nil {
log.Err(err).Msg("Failed to ensure bot is invited to room")
@ -209,67 +206,72 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
return EventHandlingResultFailed
}
portal.roomCreateLock.Lock()
defer portal.roomCreateLock.Unlock()
portalMXID := portal.MXID
if portalMXID != "" {
sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL())
rejectInvite(ctx, evt, br.Bot, "")
return EventHandlingResultSuccess
}
err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
if err != nil {
log.Err(err).Msg("Failed to give permissions to bridge bot")
sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot")
rejectInvite(ctx, evt, br.Bot, "")
return EventHandlingResultSuccess
}
overrideIntent := invitedGhost.Intent
if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
log.Debug().
Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
Msg("Created DM was redirected to another user ID")
_, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
Parsed: &event.MemberEventContent{
Membership: event.MembershipLeave,
Reason: "Direct chat redirected to another internal user ID",
didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID)
if didSetPortal {
message := "Private chat portal created"
err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
hasWarning := false
if err != nil {
log.Warn().Err(err).Msg("Failed to give power to bot in new DM")
message += "\n\nWarning: failed to promote bot"
hasWarning = true
}
if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
log.Debug().
Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
Msg("Created DM was redirected to another user ID")
_, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
Parsed: &event.MemberEventContent{
Membership: event.MembershipLeave,
Reason: "Direct chat redirected to another internal user ID",
},
}, time.Time{})
if err != nil {
log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
}
otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo)
if err != nil {
log.Err(err).Msg("Failed to get ghost of real portal other user ID")
} else {
invitedGhost = otherUserGhost
}
}
if resp.PortalInfo != nil {
portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{})
} else {
portal.UpdateCapabilities(ctx, sourceLogin, true)
portal.UpdateBridgeInfo(ctx)
}
// TODO this might become unnecessary if UpdateInfo starts taking care of it
_, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
Parsed: &event.ElementFunctionalMembersContent{
ServiceMembers: []id.UserID{br.Bot.GetMXID()},
},
}, time.Time{})
if err != nil {
log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
log.Warn().Err(err).Msg("Failed to set service members in room")
if !hasWarning {
message += "\n\nWarning: failed to set service members"
hasWarning = true
}
}
if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot {
overrideIntent = br.Bot
} else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil {
log.Err(err).Msg("Failed to get ghost of real portal other user ID")
} else {
invitedGhost = otherUserGhost
overrideIntent = otherUserGhost.Intent
mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
if ok {
err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
if err != nil {
if hasWarning {
message += fmt.Sprintf(", %s", err.Error())
} else {
message += fmt.Sprintf("\n\nWarning: %s", err.Error())
}
}
}
sendNotice(ctx, evt, invitedGhost.Intent, message)
} else {
// TODO ensure user is invited even if PortalInfo wasn't provided?
sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL())
rejectInvite(ctx, evt, br.Bot, "")
}
err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{
// We locked it before checking the mxid
RoomCreateAlreadyLocked: true,
FailIfMXIDSet: true,
ChatInfo: resp.PortalInfo,
ChatInfoSource: sourceLogin,
})
if err != nil {
log.Err(err).Msg("Failed to update Matrix room ID for new DM portal")
sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work")
return EventHandlingResultSuccess
}
message := "Private chat portal created"
mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
if ok {
err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
if err != nil {
log.Err(err).Msg("Error in connector newly bridged room handler")
message += fmt.Sprintf("\n\nWarning: %s", err.Error())
}
}
sendNotice(ctx, evt, overrideIntent, message)
return EventHandlingResultSuccess
}
@ -292,3 +294,21 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith
}
return nil
}
func (portal *Portal) setMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
portal.roomCreateLock.Lock()
defer portal.roomCreateLock.Unlock()
if portal.MXID != "" {
return false
}
portal.MXID = roomID
portal.updateLogger()
portal.Bridge.cacheLock.Lock()
portal.Bridge.portalsByMXID[portal.MXID] = portal
portal.Bridge.cacheLock.Unlock()
err := portal.Save(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating mxid")
}
return true
}

View file

@ -20,7 +20,6 @@ import (
type MessageStatusEventInfo struct {
RoomID id.RoomID
TransactionID string
SourceEventID id.EventID
NewEventID id.EventID
EventType event.Type
@ -42,7 +41,6 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo {
return &MessageStatusEventInfo{
RoomID: evt.RoomID,
TransactionID: evt.Unsigned.TransactionID,
SourceEventID: evt.ID,
EventType: evt.Type,
MessageType: evt.Content.AsMessage().MsgType,
@ -184,10 +182,9 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe
Type: event.RelReference,
EventID: evt.SourceEventID,
},
TargetTxnID: evt.TransactionID,
Status: ms.Status,
Reason: ms.ErrorReason,
Message: ms.Message,
Status: ms.Status,
Reason: ms.ErrorReason,
Message: ms.Message,
}
if ms.InternalError != nil {
content.InternalError = ms.InternalError.Error()

View file

@ -47,8 +47,8 @@ type PortalID string
// As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true.
// The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user.
type PortalKey struct {
ID PortalID `json:"portal_id"`
Receiver UserLoginID `json:"portal_receiver,omitempty"`
ID PortalID
Receiver UserLoginID
}
func (pk PortalKey) IsEmpty() bool {

View file

@ -16,9 +16,7 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/configupgrade"
"go.mau.fi/util/ptr"
"go.mau.fi/util/random"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
@ -119,15 +117,11 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa
mediaPart.Content.EnsureHasHTML()
mediaPart.Content.Body += "\n\n" + textPart.Content.Body
mediaPart.Content.FormattedBody += "<br><br>" + textPart.Content.FormattedBody
mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions)
mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...)
} else {
mediaPart.Content.FileName = mediaPart.Content.Body
mediaPart.Content.Body = textPart.Content.Body
mediaPart.Content.Format = textPart.Content.Format
mediaPart.Content.FormattedBody = textPart.Content.FormattedBody
mediaPart.Content.Mentions = textPart.Content.Mentions
mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews
}
if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok {
metaMerger.CopyFrom(textPart.DBMetadata)
@ -261,7 +255,6 @@ type NetworkConnector interface {
}
type StoppableNetwork interface {
NetworkConnector
// Stop is called when the bridge is stopping, after all network clients have been disconnected.
Stop()
}
@ -318,16 +311,6 @@ type MaxFileSizeingNetwork interface {
SetMaxFileSize(maxSize int64)
}
type NetworkResettingNetwork interface {
NetworkConnector
// ResetHTTPTransport should recreate the HTTP client used by the bridge.
// It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable.
ResetHTTPTransport()
// ResetNetworkConnections should forcefully disconnect and restart any persistent network connections.
// ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here.
ResetNetworkConnections()
}
type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
type MatrixMessageResponse struct {
@ -359,16 +342,10 @@ type NetworkGeneralCapabilities struct {
// Should the bridge re-request user info on incoming messages even if the ghost already has info?
// By default, info is only requested for ghosts with no name, and other updating is left to events.
AggressiveUpdateInfo bool
// Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message?
// This should be enabled if the network requires each message to be marked as read independently,
// and doesn't automatically do it when sending a message.
ImplicitReadReceipts bool
// If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave])
// to handle asynchronous message responses, this field can be set to enable
// automatic timeout errors in case the asynchronous response never arrives.
OutgoingMessageTimeouts *OutgoingTimeoutConfig
// Capabilities related to the provisioning API.
Provisioning ProvisioningCapabilities
}
// NetworkAPI is an interface representing a remote network client for a single user login.
@ -702,35 +679,6 @@ type RoomTopicHandlingNetworkAPI interface {
HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error)
}
type DisappearTimerChangingNetworkAPI interface {
NetworkAPI
// HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed.
// This method should update the Disappear field of the Portal with the new timer and return true
// if the change was successful. If the change is not successful, then the field should not be updated.
HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error)
}
// DeleteChatHandlingNetworkAPI is an optional interface that network connectors
// can implement to delete a chat from the remote network.
type DeleteChatHandlingNetworkAPI interface {
NetworkAPI
// HandleMatrixDeleteChat is called when the user explicitly deletes a chat.
HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error
}
// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors
// can implement to accept message requests from the remote network.
type MessageRequestAcceptingNetworkAPI interface {
NetworkAPI
// HandleMatrixAcceptMessageRequest is called when the user accepts a message request.
HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error
}
type BeeperAIStreamHandlingNetworkAPI interface {
NetworkAPI
HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error
}
type ResolveIdentifierResponse struct {
// Ghost is the ghost of the user that the identifier resolves to.
// This field should be set whenever possible. However, it is not required,
@ -750,8 +698,6 @@ type ResolveIdentifierResponse struct {
Chat *CreateChatResponse
}
var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10))
type CreateChatResponse struct {
PortalKey networkid.PortalKey
// Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary.
@ -760,17 +706,6 @@ type CreateChatResponse struct {
// If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user,
// this field should have the user ID of said different user.
DMRedirectedTo networkid.UserID
FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant
}
type CreateChatFailedParticipant struct {
Reason string `json:"reason"`
InviteEventType string `json:"invite_event_type,omitempty"`
InviteContent *event.Content `json:"invite_content,omitempty"`
UserMXID id.UserID `json:"user_mxid,omitempty"`
DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"`
}
// IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats.
@ -805,83 +740,7 @@ type UserSearchingNetworkAPI interface {
type GroupCreatingNetworkAPI interface {
IdentifierResolvingNetworkAPI
CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error)
}
type PersonalFilteringCustomizingNetworkAPI interface {
NetworkAPI
CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom)
}
type ProvisioningCapabilities struct {
ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"`
GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"`
}
type ResolveIdentifierCapabilities struct {
// Can DMs be created after resolving an identifier?
CreateDM bool `json:"create_dm"`
// Can users be looked up by phone number?
LookupPhone bool `json:"lookup_phone"`
// Can users be looked up by email address?
LookupEmail bool `json:"lookup_email"`
// Can users be looked up by network-specific username?
LookupUsername bool `json:"lookup_username"`
// Can any phone number be contacted without having to validate it via lookup first?
AnyPhone bool `json:"any_phone"`
// Can a contact list be retrieved from the bridge?
ContactList bool `json:"contact_list"`
// Can users be searched by name on the remote network?
Search bool `json:"search"`
}
type GroupTypeCapabilities struct {
TypeDescription string `json:"type_description"`
Name GroupFieldCapability `json:"name"`
Username GroupFieldCapability `json:"username"`
Avatar GroupFieldCapability `json:"avatar"`
Topic GroupFieldCapability `json:"topic"`
Disappear GroupFieldCapability `json:"disappear"`
Participants GroupFieldCapability `json:"participants"`
Parent GroupFieldCapability `json:"parent"`
}
type GroupFieldCapability struct {
// Is setting this field allowed at all in the create request?
// Even if false, the network connector should attempt to set the metadata after group creation,
// as the allowed flag can't be enforced properly when creating a group for an existing Matrix room.
Allowed bool `json:"allowed"`
// Is setting this field mandatory for the creation to succeed?
Required bool `json:"required,omitempty"`
// The minimum/maximum length of the field, if applicable.
// For members, length means the number of members excluding the creator.
MinLength int `json:"min_length,omitempty"`
MaxLength int `json:"max_length,omitempty"`
// Only for the disappear field: allowed disappearing settings
DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"`
// This can be used to tell provisionutil not to call ValidateUserID on each participant.
// It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs.
SkipIdentifierValidation bool `json:"-"`
}
type GroupCreateParams struct {
Type string `json:"type,omitempty"`
Username string `json:"username,omitempty"`
// Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs
Participants []networkid.UserID `json:"participants,omitempty"`
Parent *networkid.PortalKey `json:"parent,omitempty"`
Name *event.RoomNameEventContent `json:"name,omitempty"`
Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"`
Topic *event.TopicEventContent `json:"topic,omitempty"`
Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"`
// An existing room ID to bridge to. If unset, a new room will be created.
RoomID id.RoomID `json:"room_id,omitempty"`
CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error)
}
type MembershipChangeType struct {
@ -921,15 +780,16 @@ type MatrixMembershipChange struct {
MatrixRoomMeta[*event.MemberEventContent]
Target GhostOrUserLogin
Type MembershipChangeType
}
type MatrixMembershipResult struct {
RedirectTo networkid.UserID
// Deprecated: Use Target instead
TargetGhost *Ghost
// Deprecated: Use Target instead
TargetUserLogin *UserLogin
}
type MembershipHandlingNetworkAPI interface {
NetworkAPI
HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error)
HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error)
}
type SinglePowerLevelChange struct {
@ -1168,11 +1028,6 @@ type RemoteChatDelete interface {
RemoteDeleteOnlyForMe
}
type RemoteChatDeleteWithChildren interface {
RemoteChatDelete
DeleteChildren() bool
}
type RemoteEventThatMayCreatePortal interface {
RemoteEvent
ShouldCreatePortal() bool
@ -1405,14 +1260,12 @@ type MatrixMessageRemove struct {
type MatrixRoomMeta[ContentType any] struct {
MatrixEventBase[ContentType]
PrevContent ContentType
IsStateRequest bool
PrevContent ContentType
}
type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent]
type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent]
type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent]
type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer]
type MatrixReadReceipt struct {
Portal *Portal
@ -1427,8 +1280,6 @@ type MatrixReadReceipt struct {
LastRead time.Time
// The receipt metadata.
Receipt event.ReadReceipt
// Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt.
Implicit bool
}
type MatrixTyping struct {
@ -1442,9 +1293,6 @@ type MatrixViewingChat struct {
Portal *Portal
}
type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent]
type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent]
type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent]
type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent]
type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent]
type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent]

File diff suppressed because it is too large Load diff

View file

@ -194,9 +194,6 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t
if err != nil {
log.Err(err).Msg("Failed to get last thread message")
return
} else if anchorMessage == nil {
log.Warn().Msg("No messages found in thread?")
return
}
resp := portal.fetchThreadBackfill(ctx, source, anchorMessage)
if resp != nil {
@ -342,7 +339,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
for i, part := range msg.Parts {
partIDs = append(partIDs, part.ID)
portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent)
part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent()
evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID)
dbMessage := &database.Message{
ID: msg.ID,
@ -383,23 +379,19 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
prevThreadEvent.MXID = evtID
out.PrevThreadEvents[*msg.ThreadRoot] = evtID
}
if msg.Disappear.Type != event.DisappearingTypeNone {
if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
if msg.Disappear.Type != database.DisappearingTypeNone {
if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer)
}
out.Disappear = append(out.Disappear, &database.DisappearingMessage{
RoomID: portal.MXID,
EventID: evtID,
Timestamp: msg.Timestamp,
DisappearingSetting: msg.Disappear,
})
}
}
slices.Sort(partIDs)
for _, reaction := range msg.Reactions {
if reaction == nil {
continue
}
reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove)
if !ok {
continue
@ -410,7 +402,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
if reaction.Timestamp.IsZero() {
reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond)
}
//lint:ignore SA4006 it's a todo
targetPart, ok := partMap[*reaction.TargetPart]
if !ok {
// TODO warning log and/or skip reaction?

View file

@ -37,8 +37,8 @@ func (portal *PortalInternals) EventLoop() {
(*Portal)(portal).eventLoop()
}
func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) {
return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt)
func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) {
return (*Portal)(portal).handleSingleEventAsync(idx, rawEvt)
}
func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context {
@ -49,10 +49,6 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any
(*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback)
}
func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
return (*Portal)(portal).unwrapBeeperSendState(ctx, evt)
}
func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) {
(*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID)
}
@ -65,8 +61,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i
return (*Portal)(portal).checkConfusableName(ctx, userID, name)
}
func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest)
func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt)
}
func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult {
@ -77,10 +73,6 @@ func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user
(*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt)
}
func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) {
(*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal)
}
func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixTyping(ctx, evt)
}
@ -125,24 +117,12 @@ func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.User
return (*Portal)(portal).getTargetUser(ctx, userID)
}
func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt)
func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt)
}
func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest)
}
func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest)
}
func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixTombstone(ctx, evt)
}
func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) {
(*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser)
func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt)
}
func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
@ -153,10 +133,6 @@ func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *Us
return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt)
}
func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) {
(*Portal)(portal).ensureFunctionalMember(ctx, ghost)
}
func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) {
return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType)
}
@ -257,10 +233,6 @@ func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, sourc
return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt)
}
func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) {
return (*Portal)(portal).findOtherLogins(ctx, source)
}
func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult {
return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt)
}
@ -269,16 +241,16 @@ func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source
return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill)
}
func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline)
func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool {
return (*Portal)(portal).updateName(ctx, name, sender, ts)
}
func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline)
func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool {
return (*Portal)(portal).updateTopic(ctx, topic, sender, ts)
}
func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline)
func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool {
return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts)
}
func (portal *PortalInternals) GetBridgeInfoStateKey() string {
@ -293,12 +265,8 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen
return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts)
}
func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool {
return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra)
}
func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) {
(*Portal)(portal).revertRoomMeta(ctx, evt)
func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool {
return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content)
}
func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) {
@ -309,10 +277,6 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha
return (*Portal)(portal).updateOtherUser(ctx, members)
}
func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool {
return (*Portal)(portal).roomIsPublic(ctx)
}
func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error {
return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts)
}
@ -333,10 +297,6 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc
return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle)
}
func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) {
(*Portal)(portal).addToUserSpaces(ctx)
}
func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) {
(*Portal)(portal).removeInPortalCache(ctx)
}
@ -400,3 +360,7 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save
func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error {
return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove)
}
func (portal *PortalInternals) SetMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
return (*Portal)(portal).setMXIDToExistingRoom(ctx, roomID)
}

View file

@ -32,40 +32,21 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
if source == target {
return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same")
}
log := zerolog.Ctx(ctx).With().
Str("action", "re-id portal").
Stringer("source_portal_key", source).
Stringer("target_portal_key", target).
Logger()
ctx = log.WithContext(ctx)
log := zerolog.Ctx(ctx)
log.Debug().Msg("Re-ID'ing portal")
defer func() {
log.Debug().Msg("Finished handling portal re-ID")
}()
acquireCacheLock := func() {
if !br.cacheLock.TryLock() {
log.Debug().Msg("Waiting for global cache lock")
br.cacheLock.Lock()
log.Debug().Msg("Acquired global cache lock after waiting")
} else {
log.Trace().Msg("Acquired global cache lock without waiting")
}
}
log.Debug().Msg("Re-ID'ing portal")
sourcePortal, err := br.GetExistingPortalByKey(ctx, source)
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err)
} else if sourcePortal == nil {
log.Debug().Msg("Source portal not found, re-ID is no-op")
return ReIDResultNoOp, nil, nil
}
if !sourcePortal.roomCreateLock.TryLock() {
if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
(*cancelCreate)()
}
log.Debug().Msg("Waiting for source portal room creation lock")
sourcePortal.roomCreateLock.Lock()
log.Debug().Msg("Acquired source portal room creation lock after waiting")
}
sourcePortal.roomCreateLock.Lock()
defer sourcePortal.roomCreateLock.Unlock()
if sourcePortal.MXID == "" {
log.Info().Msg("Source portal doesn't have Matrix room, deleting row")
@ -78,37 +59,22 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Stringer("source_portal_mxid", sourcePortal.MXID)
})
acquireCacheLock()
targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true)
if err != nil {
br.cacheLock.Unlock()
return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err)
}
if targetPortal == nil {
log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal")
err = sourcePortal.unlockedReID(ctx, target)
br.cacheLock.Unlock()
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err)
}
return ReIDResultSourceReIDd, sourcePortal, nil
}
br.cacheLock.Unlock()
if !targetPortal.roomCreateLock.TryLock() {
if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
(*cancelCreate)()
}
log.Debug().Msg("Waiting for target portal room creation lock")
targetPortal.roomCreateLock.Lock()
log.Debug().Msg("Acquired target portal room creation lock after waiting")
}
targetPortal.roomCreateLock.Lock()
defer targetPortal.roomCreateLock.Unlock()
if targetPortal.MXID == "" {
log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal")
acquireCacheLock()
defer br.cacheLock.Unlock()
err = targetPortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err)
@ -123,9 +89,6 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
return c.Stringer("target_portal_mxid", targetPortal.MXID)
})
log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal")
sourcePortal.removeInPortalCache(ctx)
acquireCacheLock()
defer br.cacheLock.Unlock()
err = sourcePortal.unlockedDelete(ctx)
if err != nil {
return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err)
@ -133,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
go func() {
_, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{
Parsed: &event.TombstoneEventContent{
Body: "This room has been merged",
Body: fmt.Sprintf("This room has been merged"),
ReplacementRoom: targetPortal.MXID,
},
}, time.Now())

View file

@ -1,149 +0,0 @@
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package provisionutil
import (
"context"
"github.com/rs/zerolog"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type RespCreateGroup struct {
ID networkid.PortalID `json:"id"`
MXID id.RoomID `json:"mxid"`
Portal *bridgev2.Portal `json:"-"`
FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"`
}
func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) {
api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI)
if !ok {
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups"))
}
zerolog.Ctx(ctx).Debug().
Any("create_params", params).
Msg("Creating group chat on remote network")
caps := login.Bridge.Network.GetCapabilities()
typeSpec, validType := caps.Provisioning.GroupCreation[params.Type]
if !validType {
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type))
}
if len(params.Participants) < typeSpec.Participants.MinLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength))
} else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength))
}
userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
for i, participant := range params.Participants {
parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant))
if ok {
participant = parsedParticipant
params.Participants[i] = participant
}
if !typeSpec.Participants.SkipIdentifierValidation {
if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant))
}
}
if api.IsThisUser(ctx, participant) {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant))
}
}
if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required"))
} else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength))
} else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength))
}
if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required"))
}
if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required"))
} else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength))
} else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength))
}
if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required"))
} else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer"))
}
if params.Username == "" && typeSpec.Username.Required {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required"))
} else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength))
} else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength))
}
if params.Parent == nil && typeSpec.Parent.Required {
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required"))
}
resp, err := api.CreateGroup(ctx, params)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to create group")
return nil, err
}
if resp.PortalKey.IsEmpty() {
return nil, ErrNoPortalKey
}
zerolog.Ctx(ctx).Debug().
Object("portal_key", resp.PortalKey).
Msg("Successfully created group on remote network")
if resp.Portal == nil {
resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
}
}
if resp.Portal.MXID == "" {
err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
}
}
for key, fp := range resp.FailedParticipants {
if fp.InviteEventType == "" {
fp.InviteEventType = event.EventMessage.Type
}
if fp.UserMXID == "" {
ghost, err := login.Bridge.GetGhostByID(ctx, key)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant")
} else if ghost != nil {
fp.UserMXID = ghost.Intent.GetMXID()
}
}
if fp.DMRoomMXID == "" {
portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant")
} else if portal != nil {
fp.DMRoomMXID = portal.MXID
}
}
}
return &RespCreateGroup{
ID: resp.Portal.ID,
MXID: resp.Portal.MXID,
Portal: resp.Portal,
FailedParticipants: resp.FailedParticipants,
}, nil
}

View file

@ -1,98 +0,0 @@
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package provisionutil
import (
"context"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
)
type RespGetContactList struct {
Contacts []*RespResolveIdentifier `json:"contacts"`
}
type RespSearchUsers struct {
Results []*RespResolveIdentifier `json:"results"`
}
func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) {
api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
if !ok {
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts"))
}
resp, err := api.GetContactList(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
return nil, err
}
return &RespGetContactList{
Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false),
}, nil
}
func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) {
api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
if !ok {
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users"))
}
resp, err := api.SearchUsers(ctx, query)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
return nil, err
}
return &RespSearchUsers{
Results: processResolveIdentifiers(ctx, login.Bridge, resp, true),
}, nil
}
func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) {
apiResp = make([]*RespResolveIdentifier, len(resp))
for i, contact := range resp {
apiContact := &RespResolveIdentifier{
ID: contact.UserID,
}
apiResp[i] = apiContact
if contact.UserInfo != nil {
if contact.UserInfo.Name != nil {
apiContact.Name = *contact.UserInfo.Name
}
if contact.UserInfo.Identifiers != nil {
apiContact.Identifiers = contact.UserInfo.Identifiers
}
}
if contact.Ghost != nil {
if syncInfo && contact.UserInfo != nil {
contact.Ghost.UpdateInfo(ctx, contact.UserInfo)
}
if contact.Ghost.Name != "" {
apiContact.Name = contact.Ghost.Name
}
if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
apiContact.Identifiers = contact.Ghost.Identifiers
}
apiContact.AvatarURL = contact.Ghost.AvatarMXC
apiContact.MXID = contact.Ghost.Intent.GetMXID()
}
if contact.Chat != nil {
if contact.Chat.Portal == nil {
var err error
contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
}
}
if contact.Chat.Portal != nil {
apiContact.DMRoomID = contact.Chat.Portal.MXID
}
}
}
return
}

View file

@ -1,125 +0,0 @@
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package provisionutil
import (
"context"
"errors"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/id"
)
type RespResolveIdentifier struct {
ID networkid.UserID `json:"id"`
Name string `json:"name,omitempty"`
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
Identifiers []string `json:"identifiers,omitempty"`
MXID id.UserID `json:"mxid,omitempty"`
DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
Portal *bridgev2.Portal `json:"-"`
Ghost *bridgev2.Ghost `json:"-"`
JustCreated bool `json:"-"`
}
var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request")
func ResolveIdentifier(
ctx context.Context,
login *bridgev2.UserLogin,
identifier string,
createChat bool,
) (*RespResolveIdentifier, error) {
api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
if !ok {
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers"))
}
var resp *bridgev2.ResolveIdentifierResponse
parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier))
validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
if ok && (!vOK || validator.ValidateUserID(parsedUserID)) {
ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID")
return nil, err
}
resp = &bridgev2.ResolveIdentifierResponse{
Ghost: ghost,
UserID: parsedUserID,
}
gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI)
if ok && createChat {
resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat")
return nil, err
}
} else if createChat || ghost.Name == "" {
zerolog.Ctx(ctx).Debug().
Bool("create_chat", createChat).
Bool("has_name", ghost.Name != "").
Msg("Falling back to resolving identifier")
resp = nil
identifier = string(parsedUserID)
}
}
if resp == nil {
var err error
resp, err = api.ResolveIdentifier(ctx, identifier, createChat)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier")
return nil, err
} else if resp == nil {
return nil, nil
}
}
apiResp := &RespResolveIdentifier{
ID: resp.UserID,
Ghost: resp.Ghost,
}
if resp.Ghost != nil {
if resp.UserInfo != nil {
resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
}
apiResp.Name = resp.Ghost.Name
apiResp.AvatarURL = resp.Ghost.AvatarMXC
apiResp.Identifiers = resp.Ghost.Identifiers
apiResp.MXID = resp.Ghost.Intent.GetMXID()
} else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
apiResp.Name = *resp.UserInfo.Name
}
if resp.Chat != nil {
if resp.Chat.PortalKey.IsEmpty() {
return nil, ErrNoPortalKey
}
if resp.Chat.Portal == nil {
var err error
resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
}
}
resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID)
if createChat && resp.Chat.Portal.MXID == "" {
apiResp.JustCreated = true
err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
}
}
apiResp.Portal = resp.Chat.Portal
apiResp.DMRoomID = resp.Chat.Portal.MXID
}
return apiResp, nil
}

View file

@ -63,13 +63,6 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve
return true
}
var (
ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage())
ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage()
)
func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult {
// TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands
@ -85,11 +78,13 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
return EventHandlingResultFailed
} else if sender == nil {
log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event")
br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
return EventHandlingResultFailed
} else if !sender.Permissions.SendEvents {
if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") {
br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt))
status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
}
return EventHandlingResultIgnored
} else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") {
@ -97,7 +92,8 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
}
} else if evt.Type.Class != event.EphemeralEventType {
log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event")
br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
return EventHandlingResultIgnored
}
if evt.Type == event.EventMessage && sender != nil {
@ -106,7 +102,8 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
msg.RemovePerMessageProfileFallback()
if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom {
if !sender.Permissions.Commands {
br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt))
status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
return EventHandlingResultIgnored
}
go br.Commands.Handle(
@ -160,27 +157,10 @@ type EventHandlingResult struct {
Ignored bool
Queued bool
SkipStateEcho bool
// Error is an optional reason for failure. It is not required, Success may be false even without a specific error.
Error error
// Whether the Error should be sent as a MSS event.
SendMSS bool
// EventID from the network
EventID id.EventID
// Stream order from the network
StreamOrder int64
}
func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult {
ehr.EventID = id
return ehr
}
func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult {
ehr.StreamOrder = order
return ehr
}
func (ehr EventHandlingResult) WithError(err error) EventHandlingResult {
@ -197,11 +177,6 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult {
return ehr
}
func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult {
ehr.SkipStateEcho = skip
return ehr
}
func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult {
if err == nil {
return ehr
@ -220,7 +195,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult {
return ul.Bridge.QueueRemoteEvent(ul, evt)
}
func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult {
func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) {
log := login.Log
ctx := log.WithContext(br.BackgroundCtx)
maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver)
@ -236,14 +211,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandl
if err != nil {
log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain).
Msg("Failed to get portal to handle remote event")
return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err))
return
} else if portal == nil {
log.Warn().
Stringer("event_type", evt.GetType()).
Object("portal_key", key).
Bool("uncertain_receiver", isUncertain).
Msg("Portal not found to handle remote event")
return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler)
return
}
// TODO put this in a better place, and maybe cache to avoid constant db queries
login.MarkInPortal(ctx, portal)

View file

@ -65,19 +65,14 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal)
type ChatDelete struct {
EventMeta
OnlyForMe bool
Children bool
}
var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil)
var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil)
func (evt *ChatDelete) DeleteOnlyForMe() bool {
return evt.OnlyForMe
}
func (evt *ChatDelete) DeleteChildren() bool {
return evt.Children
}
// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange].
type ChatInfoChange struct {
EventMeta

View file

@ -59,41 +59,6 @@ func (evt *Message[T]) GetTransactionID() networkid.TransactionID {
return evt.TransactionID
}
// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data.
type PreConvertedMessage struct {
EventMeta
Data *bridgev2.ConvertedMessage
ID networkid.MessageID
TransactionID networkid.TransactionID
HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error)
}
var (
_ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil)
_ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil)
_ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil)
)
func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) {
return evt.Data, nil
}
func (evt *PreConvertedMessage) GetID() networkid.MessageID {
return evt.ID
}
func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID {
return evt.TransactionID
}
func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) {
if evt.HandleExistingFunc == nil {
return bridgev2.UpsertResult{}, nil
}
return evt.HandleExistingFunc(ctx, portal, intent, existing)
}
type MessageRemove struct {
EventMeta

View file

@ -101,18 +101,6 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E
return evt
}
func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta {
origFunc := evt.LogContext
if origFunc == nil {
evt.LogContext = f
return evt
}
evt.LogContext = func(c zerolog.Context) zerolog.Context {
return f(origFunc(c))
}
return evt
}
func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta {
evt.PortalKey = p
return evt

View file

@ -164,17 +164,14 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) {
ul.UserMXID: 50,
},
},
Invite: []id.UserID{ul.UserMXID},
RoomVersion: id.RoomV11,
Invite: []id.UserID{ul.UserMXID},
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{ul.UserMXID}
// TODO remove this after initial_members is supported in hungryserv
req.BeeperAutoJoinInvites = true
}
pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI)
if ok {
pfc.CustomizePersonalFilteringSpace(req)
}
ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req)
if err != nil {
return "", fmt.Errorf("failed to create space room: %w", err)

View file

@ -19,10 +19,9 @@ import (
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -88,8 +87,6 @@ type RemoteProfile struct {
Username string `json:"username,omitempty"`
Name string `json:"name,omitempty"`
Avatar id.ContentURIString `json:"avatar,omitempty"`
AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"`
}
func coalesce[T ~string](a, b T) T {
@ -105,14 +102,11 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile {
other.Username = coalesce(rp.Username, other.Username)
other.Name = coalesce(rp.Name, other.Name)
other.Avatar = coalesce(rp.Avatar, other.Avatar)
if rp.AvatarFile != nil {
other.AvatarFile = rp.AvatarFile
}
return other
}
func (rp *RemoteProfile) IsZero() bool {
return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil)
func (rp *RemoteProfile) IsEmpty() bool {
return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "")
}
type BridgeState struct {
@ -126,10 +120,10 @@ type BridgeState struct {
UserAction BridgeStateUserAction `json:"user_action,omitempty"`
UserID id.UserID `json:"user_id,omitempty"`
RemoteID networkid.UserLoginID `json:"remote_id,omitempty"`
RemoteName string `json:"remote_name,omitempty"`
RemoteProfile RemoteProfile `json:"remote_profile,omitzero"`
UserID id.UserID `json:"user_id,omitempty"`
RemoteID string `json:"remote_id,omitempty"`
RemoteName string `json:"remote_name,omitempty"`
RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"`
Reason string `json:"reason,omitempty"`
Info map[string]interface{} `json:"info,omitempty"`
@ -209,7 +203,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool {
pong.StateEvent == newPong.StateEvent &&
pong.RemoteName == newPong.RemoteName &&
pong.UserAction == newPong.UserAction &&
pong.RemoteProfile == newPong.RemoteProfile &&
ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) &&
pong.Error == newPong.Error &&
maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) &&
pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now())

View file

@ -176,10 +176,6 @@ func (user *User) GetUserLogins() []*UserLogin {
return maps.Values(user.logins)
}
func (user *User) HasTooManyLogins() bool {
return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins
}
func (user *User) GetFormattedUserLogins() string {
user.Bridge.cacheLock.Lock()
logins := make([]string, len(user.logins))
@ -229,8 +225,9 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) {
user.MXID: 50,
},
},
Invite: []id.UserID{user.MXID},
IsDirect: true,
RoomVersion: id.RoomV11,
Invite: []id.UserID{user.MXID},
IsDirect: true,
}
if autoJoin {
req.BeeperInitialMembers = []id.UserID{user.MXID}

View file

@ -10,7 +10,6 @@ import (
"cmp"
"context"
"fmt"
"maps"
"slices"
"sync"
"time"
@ -51,8 +50,6 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
// TODO if loading the user caused the provided userlogin to be loaded, cancel here?
// Currently this will double-load it
}
userLogin := &UserLogin{
UserLogin: dbUserLogin,
@ -143,12 +140,6 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin {
return br.userLoginsByID[id]
}
func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
return slices.Collect(maps.Values(br.userLoginsByID))
}
func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) {
br.cacheLock.Lock()
defer br.cacheLock.Unlock()
@ -510,9 +501,9 @@ var _ status.BridgeStateFiller = (*UserLogin)(nil)
func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState {
state.UserID = ul.UserMXID
state.RemoteID = ul.ID
state.RemoteID = string(ul.ID)
state.RemoteName = ul.RemoteName
state.RemoteProfile = ul.RemoteProfile
state.RemoteProfile = &ul.RemoteProfile
filler, ok := ul.Client.(status.BridgeStateFiller)
if ok {
return filler.FillBridgeState(state)

521
client.go
View file

@ -111,8 +111,6 @@ type Client struct {
// Set to true to disable automatically sleeping on 429 errors.
IgnoreRateLimit bool
ResponseSizeLimit int64
txnID int32
// Should the ?user_id= query parameter be set in requests?
@ -145,8 +143,6 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown
return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName)
}
const WellKnownMaxSize = 64 * 1024
func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) {
wellKnownURL := url.URL{
Scheme: "https",
@ -172,15 +168,11 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve
if resp.StatusCode == http.StatusNotFound {
return nil, nil
} else if resp.ContentLength > WellKnownMaxSize {
return nil, errors.New(".well-known response too large")
}
data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize))
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
} else if len(data) >= WellKnownMaxSize {
return nil, errors.New(".well-known response too large")
}
var wellKnown ClientWellKnown
@ -331,7 +323,6 @@ const (
LogBodyContextKey contextKey = iota
LogRequestIDContextKey
MaxAttemptsContextKey
SyncTokenContextKey
)
func (cli *Client) RequestStart(req *http.Request) {
@ -386,14 +377,7 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er
}
}
if body := req.Context().Value(LogBodyContextKey); body != nil {
switch typedLogBody := body.(type) {
case json.RawMessage:
evt.RawJSON("req_body", typedLogBody)
case string:
evt.Str("req_body", typedLogBody)
default:
panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body))
}
evt.Interface("req_body", body)
}
if errors.Is(err, context.Canceled) {
evt.Msg("Request canceled")
@ -410,43 +394,32 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin
return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody})
}
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error)
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error)
type FullRequest struct {
Method string
URL string
Headers http.Header
RequestJSON interface{}
RequestBytes []byte
RequestBody io.Reader
RequestLength int64
ResponseJSON interface{}
MaxAttempts int
BackoffDuration time.Duration
SensitiveContent bool
Handler ClientResponseHandler
DontReadResponse bool
ResponseSizeLimit int64
Logger *zerolog.Logger
Client *http.Client
Method string
URL string
Headers http.Header
RequestJSON interface{}
RequestBytes []byte
RequestBody io.Reader
RequestLength int64
ResponseJSON interface{}
MaxAttempts int
BackoffDuration time.Duration
SensitiveContent bool
Handler ClientResponseHandler
DontReadResponse bool
Logger *zerolog.Logger
Client *http.Client
}
var requestID int32
var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes"
func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) {
reqID := atomic.AddInt32(&requestID, 1)
logger := zerolog.Ctx(ctx)
if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
logger = params.Logger
}
ctx = logger.With().
Int32("req_id", reqID).
Logger().WithContext(ctx)
var logBody any
var reqBody io.Reader
var reqLen int64
reqBody := params.RequestBody
if params.RequestJSON != nil {
jsonStr, err := json.Marshal(params.RequestJSON)
if err != nil {
@ -457,38 +430,33 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
}
if params.SensitiveContent && !logSensitiveContent {
logBody = "<sensitive content omitted>"
} else if len(jsonStr) > 32768 {
logBody = fmt.Sprintf("<large content omitted (%d bytes)>", len(jsonStr))
} else {
logBody = json.RawMessage(jsonStr)
logBody = params.RequestJSON
}
reqBody = bytes.NewReader(jsonStr)
reqLen = int64(len(jsonStr))
} else if params.RequestBytes != nil {
logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes))
reqBody = bytes.NewReader(params.RequestBytes)
reqLen = int64(len(params.RequestBytes))
} else if params.RequestBody != nil {
logBody = "<unknown stream of bytes>"
reqLen = -1
if params.RequestLength > 0 {
logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
reqLen = params.RequestLength
} else if params.RequestLength == 0 {
zerolog.Ctx(ctx).Warn().
Msg("RequestBody passed without specifying request length")
}
reqBody = params.RequestBody
params.RequestLength = int64(len(params.RequestBytes))
} else if params.RequestLength > 0 && params.RequestBody != nil {
logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok {
// Prevent HTTP from closing the request body, it might be needed for retries
reqBody = nopCloseSeeker{rsc}
}
} else if params.Method != http.MethodGet && params.Method != http.MethodHead {
params.RequestJSON = struct{}{}
logBody = json.RawMessage("{}")
logBody = params.RequestJSON
reqBody = bytes.NewReader([]byte("{}"))
reqLen = 2
}
reqID := atomic.AddInt32(&requestID, 1)
logger := zerolog.Ctx(ctx)
if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
logger = params.Logger
}
ctx = logger.With().
Int32("req_id", reqID).
Logger().WithContext(ctx)
ctx = context.WithValue(ctx, LogBodyContextKey, logBody)
ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID))
req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody)
@ -504,7 +472,9 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
if params.RequestJSON != nil {
req.Header.Set("Content-Type", "application/json")
}
req.ContentLength = reqLen
if params.RequestLength > 0 && params.RequestBody != nil {
req.ContentLength = params.RequestLength
}
return req, nil
}
@ -555,25 +525,10 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque
if len(cli.AccessToken) > 0 {
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
}
if params.ResponseSizeLimit == 0 {
params.ResponseSizeLimit = cli.ResponseSizeLimit
}
if params.ResponseSizeLimit == 0 {
params.ResponseSizeLimit = DefaultResponseSizeLimit
}
if params.Client == nil {
params.Client = cli.Client
}
return cli.executeCompiledRequest(
req,
params.MaxAttempts-1,
params.BackoffDuration,
params.ResponseJSON,
params.Handler,
params.DontReadResponse,
params.ResponseSizeLimit,
params.Client,
)
return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client)
}
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
@ -584,17 +539,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
return log
}
func (cli *Client) doRetry(
req *http.Request,
cause error,
retries int,
backoff time.Duration,
responseJSON any,
handler ClientResponseHandler,
dontReadResponse bool,
sizeLimit int64,
client *http.Client,
) ([]byte, *http.Response, error) {
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
log := zerolog.Ctx(req.Context())
if req.Body != nil {
var err error
@ -623,30 +568,16 @@ func (cli *Client) doRetry(
select {
case <-time.After(backoff):
case <-req.Context().Done():
if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) {
return nil, nil, req.Context().Err()
}
return nil, nil, req.Context().Err()
}
if cli.UpdateRequestOnRetry != nil {
req = cli.UpdateRequestOnRetry(req, cause)
}
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client)
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client)
}
func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) {
if res.ContentLength > limit {
return nil, HTTPError{
Request: req,
Response: res,
Message: "not reading response",
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
}
}
contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1))
if err == nil && len(contents) > int(limit) {
err = ErrBodyReadReachedLimit
}
func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) {
contents, err := io.ReadAll(res.Body)
if err != nil {
return nil, HTTPError{
Request: req,
@ -667,20 +598,17 @@ func closeTemp(log *zerolog.Logger, file *os.File) {
}
}
func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
log := zerolog.Ctx(req.Context())
file, err := os.CreateTemp("", "mautrix-response-")
if err != nil {
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
_, err = handleNormalResponse(req, res, responseJSON, limit)
_, err = handleNormalResponse(req, res, responseJSON)
return nil, err
}
defer closeTemp(log, file)
var n int64
if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil {
if _, err = io.Copy(file, res.Body); err != nil {
return nil, fmt.Errorf("failed to copy response to file: %w", err)
} else if n > limit {
return nil, ErrBodyReadReachedLimit
} else if _, err = file.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err)
} else if err = json.NewDecoder(file).Decode(responseJSON); err != nil {
@ -690,12 +618,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON any, lim
}
}
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
return nil, nil
}
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
if contents, err := readResponseBody(req, res, limit); err != nil {
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
if contents, err := readResponseBody(req, res); err != nil {
return nil, err
} else if responseJSON == nil {
return contents, nil
@ -713,13 +641,8 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON an
}
}
const ErrorResponseSizeLimit = 512 * 1024
var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024
func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
defer res.Body.Close()
contents, err := readResponseBody(req, res, ErrorResponseSizeLimit)
contents, err := readResponseBody(req, res)
if err != nil {
return contents, err
}
@ -738,31 +661,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
}
}
func (cli *Client) executeCompiledRequest(
req *http.Request,
retries int,
backoff time.Duration,
responseJSON any,
handler ClientResponseHandler,
dontReadResponse bool,
sizeLimit int64,
client *http.Client,
) ([]byte, *http.Response, error) {
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
cli.RequestStart(req)
startTime := time.Now()
res, err := client.Do(req)
duration := time.Since(startTime)
duration := time.Now().Sub(startTime)
if res != nil && !dontReadResponse {
defer res.Body.Close()
}
if err != nil {
// Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry
canRetry := !errors.Is(err, context.Canceled) ||
errors.Is(context.Cause(req.Context()), ErrContextCancelRetry)
if retries > 0 && canRetry {
return cli.doRetry(
req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
)
if retries > 0 && !errors.Is(err, context.Canceled) {
return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client)
}
err = HTTPError{
Request: req,
@ -777,9 +686,7 @@ func (cli *Client) executeCompiledRequest(
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
return cli.doRetry(
req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
)
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client)
}
var body []byte
@ -787,7 +694,7 @@ func (cli *Client) executeCompiledRequest(
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
} else {
body, err = handler(req, res, responseJSON, sizeLimit)
body, err = handler(req, res, responseJSON)
cli.LogRequestDone(req, res, nil, err, len(body), duration)
}
return body, res, err
@ -847,7 +754,7 @@ func (req *ReqSync) BuildQuery() map[string]string {
query["full_state"] = "true"
}
if req.UseStateAfter {
query["use_state_after"] = "true"
query["org.matrix.msc4222.use_state_after"] = "true"
}
if req.BeeperStreaming {
query["com.beeper.streaming"] = "true"
@ -871,7 +778,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp
}
start := time.Now()
_, err = cli.MakeFullRequest(ctx, fullReq)
duration := time.Since(start)
duration := time.Now().Sub(start)
timeout := time.Duration(req.Timeout) * time.Millisecond
buffer := 10 * time.Second
if req.Since == "" {
@ -918,7 +825,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp
return
}
func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
var bodyBytes []byte
bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
@ -942,7 +849,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[an
// Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
//
// Registers with kind=user. For kind=guest, see RegisterGuest.
func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
u := cli.BuildClientURL("v3", "register")
return cli.register(ctx, u, req)
}
@ -951,7 +858,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRe
// with kind=guest.
//
// For kind=user, see Register.
func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
query := map[string]string{
"kind": "guest",
}
@ -974,8 +881,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*R
// panic(err)
// }
// token := res.AccessToken
func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) {
_, uia, err := cli.Register(ctx, req)
func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) {
res, uia, err := cli.Register(ctx, req)
if err != nil && uia == nil {
return nil, err
} else if uia == nil {
@ -984,7 +891,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*R
return nil, errors.New("server does not support m.login.dummy")
}
req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session}
res, _, err := cli.Register(ctx, req)
res, _, err = cli.Register(ctx, req)
if err != nil {
return nil, err
}
@ -1148,19 +1055,8 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs
return
}
func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) {
urlPath := cli.BuildClientURL("v3", "user_directory", "search")
_, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{
SearchTerm: query,
Limit: limit,
}, &resp)
return
}
func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) {
supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms)
supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms)
if cli.SpecVersions != nil && !supportsUnstable && !supportsStable {
if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) {
err = fmt.Errorf("server does not support fetching mutual rooms")
return
}
@ -1170,10 +1066,7 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex
if len(extras) > 0 {
query["from"] = extras[0].From
}
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query)
if !supportsStable && supportsUnstable {
urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
}
urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
@ -1195,7 +1088,8 @@ func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via
// GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname
func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) {
err = cli.GetProfileField(ctx, mxid, "displayname", &resp)
urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname")
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
@ -1206,47 +1100,41 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay
// SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname
func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) {
return cli.SetProfileField(ctx, "displayname", displayName)
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname")
s := struct {
DisplayName string `json:"displayname"`
}{displayName}
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil)
return
}
// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) {
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
}
// UnstableSetProfileField sets an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
func (cli *Client) UnstableSetProfileField(ctx context.Context, key string, value any) (err error) {
urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{
key: value,
}, nil)
return
}
// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) {
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
}
// UnstableDeleteProfileField deletes an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
func (cli *Client) UnstableDeleteProfileField(ctx context.Context, key string) (err error) {
urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
return
}
// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname
func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) {
urlPath := cli.BuildClientURL("v3", "profile", userID, key)
if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
}
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into)
return
}
// GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url
func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) {
urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url")
s := struct {
AvatarURL id.ContentURI `json:"avatar_url"`
}{}
err = cli.GetProfileField(ctx, mxid, "avatar_url", &s)
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s)
if err != nil {
return
}
url = s.AvatarURL
return
}
@ -1338,9 +1226,6 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
if req.UnstableStickyDuration > 0 {
queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
}
if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted {
var isEncrypted bool
@ -1364,51 +1249,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
return
}
// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint.
// contentJSON should be a value that can be encoded as JSON using json.Marshal.
func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
var req ReqSendEvent
if len(extra) > 0 {
req = extra[0]
}
var txnID string
if len(req.TransactionID) > 0 {
txnID = req.TransactionID
} else {
txnID = cli.TxnID()
}
queryParams := map[string]string{}
if req.Timestamp > 0 {
queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
}
if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted {
var isEncrypted bool
isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to check if room is encrypted: %w", err)
return
}
if isEncrypted {
if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil {
err = fmt.Errorf("failed to encrypt event: %w", err)
return
}
eventType = event.EventEncrypted
}
}
urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID}
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
return
}
// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
var req ReqSendEvent
if len(extra) > 0 {
req = extra[0]
@ -1418,18 +1261,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
if req.MeowEventID != "" {
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
}
if req.TransactionID != "" {
queryParams["fi.mau.transaction_id"] = req.TransactionID
}
if req.UnstableDelay > 0 {
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
}
if req.UnstableStickyDuration > 0 {
queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
}
if req.Timestamp > 0 {
queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
}
urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
@ -1442,38 +1276,14 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
// SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
//
// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead.
func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{
Timestamp: ts,
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
"ts": strconv.FormatInt(ts, 10),
})
return
}
func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) {
query := map[string]string{}
if req.DelayID != "" {
query["delay_id"] = string(req.DelayID)
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
}
if req.Status != "" {
query["status"] = string(req.Status)
}
if req.NextBatch != "" {
query["next_batch"] = req.NextBatch
}
urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query)
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp)
// Migration: merge old keys with new ones
if resp != nil {
resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...)
resp.DelayedEvents = nil
resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...)
resp.FinalisedEvents = nil
}
return
}
@ -1766,20 +1576,11 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy
}
// parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map.
func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
if res.ContentLength > limit {
return nil, HTTPError{
Request: req,
Response: res,
Message: "not reading response",
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
}
}
func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
response := make(RoomStateMap)
responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event)
*responsePtr = response
dec := json.NewDecoder(io.LimitReader(res.Body, limit))
dec := json.NewDecoder(res.Body)
arrayStart, err := dec.Token()
if err != nil {
@ -1813,8 +1614,6 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any
return nil, nil
}
type RoomStateMap = map[event.Type]map[string]*event.Event
// State gets all state in a room.
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate
func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) {
@ -1897,9 +1696,6 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa
}
func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) {
if mxcURL.IsEmpty() {
return nil, fmt.Errorf("empty mxc uri provided to Download")
}
_, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
Method: http.MethodGet,
URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID),
@ -1908,41 +1704,6 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re
return resp, err
}
type DownloadThumbnailExtra struct {
Method string
Animated bool
}
func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) {
if mxcURL.IsEmpty() {
return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail")
}
if len(extras) > 1 {
panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras)))
}
var extra DownloadThumbnailExtra
if len(extras) == 1 {
extra = extras[0]
}
path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID}
query := map[string]string{
"height": strconv.Itoa(height),
"width": strconv.Itoa(width),
}
if extra.Method != "" {
query["method"] = extra.Method
}
if extra.Animated {
query["animated"] = "true"
}
_, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
Method: http.MethodGet,
URL: cli.BuildURLWithQuery(path, query),
DontReadResponse: true,
})
return resp, err
}
func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) {
resp, err := cli.Download(ctx, mxcURL)
if err != nil {
@ -1989,15 +1750,10 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr
}
req.MXC = resp.ContentURI
req.UnstableUploadURL = resp.UnstableUploadURL
if req.AsyncContext == nil {
req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background())
}
go func() {
_, err = cli.UploadMedia(req.AsyncContext, req)
_, err = cli.UploadMedia(ctx, req)
if err != nil {
zerolog.Ctx(req.AsyncContext).Err(err).
Stringer("mxc", req.MXC).
Msg("Async upload of media failed")
cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed")
}
}()
return resp, nil
@ -2033,7 +1789,6 @@ type ReqUploadMedia struct {
ContentType string
FileName string
AsyncContext context.Context
DoneCallback func()
// MXC specifies an existing MXC URI which doesn't have content yet to upload into.
@ -2046,10 +1801,7 @@ type ReqUploadMedia struct {
}
func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) {
cli.Log.Debug().
Str("url", url).
Int64("content_length", contentLength).
Msg("Uploading media to external URL")
cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL")
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content)
if err != nil {
return nil, err
@ -2098,16 +1850,8 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*
Msg("Error uploading media to external URL, not retrying")
return nil, err
}
backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries)
cli.Log.Warn().Err(err).
Str("url", data.UnstableUploadURL).
Int("retry_in_seconds", int(backoff.Seconds())).
cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err).
Msg("Error uploading media to external URL, retrying")
select {
case <-time.After(backoff):
case <-ctx.Done():
return nil, ctx.Err()
}
retries--
_, err = readerSeeker.Seek(0, io.SeekStart)
if err != nil {
@ -2687,15 +2431,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req
return err
}
func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error {
func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error {
urlPath := cli.BuildClientURL("v3", "devices", deviceID)
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err
}
func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error {
func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error {
urlPath := cli.BuildClientURL("v3", "delete_devices")
_, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil)
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err
}
@ -2704,7 +2448,7 @@ type UIACallback = func(*RespUserInteractive) interface{}
// UploadCrossSigningKeys uploads the given cross-signing keys to the server.
// Because the endpoint requires user-interactive authentication a callback must be provided that,
// given the UI auth parameters, produces the required result (or nil to end the flow).
func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error {
func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error {
content, err := cli.MakeFullRequest(ctx, FullRequest{
Method: http.MethodPost,
URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"),
@ -2786,61 +2530,24 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri
return err
}
// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin.
// BatchSend sends a batch of historical events into a room. This is only available for appservices.
//
// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) {
urlPath := cli.BuildClientURL("v3", "admin", "whois", userID)
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
func (cli *Client) makeMSC4323URL(action string, target id.UserID) string {
if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) {
return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target)
} else if cli.SpecVersions.Supports(FeatureStableAccountModeration) {
return cli.BuildClientURL("v1", "admin", action, target)
// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead.
func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) {
path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"}
query := map[string]string{
"prev_event_id": req.PrevEventID.String(),
}
return ""
}
// GetSuspendedStatus uses MSC4323 to check if a user is suspended.
func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) {
urlPath := cli.makeMSC4323URL("suspend", userID)
if urlPath == "" {
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
if req.BeeperNewMessages {
query["com.beeper.new_messages"] = "true"
}
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
return
}
// GetLockStatus uses MSC4323 to check if a user is locked.
func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) {
urlPath := cli.makeMSC4323URL("lock", userID)
if urlPath == "" {
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
if req.BeeperMarkReadBy != "" {
query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String()
}
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
return
}
// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended.
func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) {
urlPath := cli.makeMSC4323URL("suspend", userID)
if urlPath == "" {
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
if len(req.BatchID) > 0 {
query["batch_id"] = req.BatchID.String()
}
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res)
return
}
// SetLockStatus uses MSC4323 to set whether a user account is locked.
func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) {
urlPath := cli.makeMSC4323URL("lock", userID)
if urlPath == "" {
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
}
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res)
_, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp)
return
}

View file

@ -1,158 +0,0 @@
// Copyright (c) 2026 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package mautrix_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) {
roomID := id.RoomID("!room:example.com")
evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
txnID := "txn-123"
var gotPath string
var gotQueryTS string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotQueryTS = r.URL.Query().Get("ts")
assert.Equal(t, http.MethodPut, r.Method)
_, _ = w.Write([]byte(`{"event_id":"$evt"}`))
}))
defer ts.Close()
cli, err := mautrix.NewClient(ts.URL, "", "")
require.NoError(t, err)
_, err = cli.BeeperSendEphemeralEvent(
context.Background(),
roomID,
evtType,
map[string]any{"foo": "bar"},
mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234},
)
require.NoError(t, err)
assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/"))
assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID))
assert.Equal(t, "1234", gotQueryTS)
}
func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`))
}))
defer ts.Close()
cli, err := mautrix.NewClient(ts.URL, "", "")
require.NoError(t, err)
_, err = cli.BeeperSendEphemeralEvent(
context.Background(),
id.RoomID("!room:example.com"),
event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType},
map[string]any{"foo": "bar"},
)
require.Error(t, err)
assert.True(t, errors.Is(err, mautrix.MUnrecognized))
}
func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) {
roomID := id.RoomID("!room:example.com")
evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
txnID := "txn-encrypted"
stateStore := mautrix.NewMemoryStateStore()
err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{
Algorithm: id.AlgorithmMegolmV1,
})
require.NoError(t, err)
fakeCrypto := &fakeCryptoHelper{
encryptedContent: &event.EncryptedEventContent{
Algorithm: id.AlgorithmMegolmV1,
MegolmCiphertext: []byte("ciphertext"),
},
}
var gotPath string
var gotBody map[string]any
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
assert.Equal(t, http.MethodPut, r.Method)
err := json.NewDecoder(r.Body).Decode(&gotBody)
require.NoError(t, err)
_, _ = w.Write([]byte(`{"event_id":"$evt"}`))
}))
defer ts.Close()
cli, err := mautrix.NewClient(ts.URL, "", "")
require.NoError(t, err)
cli.StateStore = stateStore
cli.Crypto = fakeCrypto
_, err = cli.BeeperSendEphemeralEvent(
context.Background(),
roomID,
evtType,
map[string]any{"foo": "bar"},
mautrix.ReqSendEvent{TransactionID: txnID},
)
require.NoError(t, err)
assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID))
assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"])
assert.Equal(t, 1, fakeCrypto.encryptCalls)
assert.Equal(t, roomID, fakeCrypto.lastRoomID)
assert.Equal(t, evtType, fakeCrypto.lastEventType)
}
type fakeCryptoHelper struct {
encryptCalls int
lastRoomID id.RoomID
lastEventType event.Type
lastEncryptInput any
encryptedContent *event.EncryptedEventContent
}
func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) {
f.encryptCalls++
f.lastRoomID = roomID
f.lastEventType = eventType
f.lastEncryptInput = content
return f.encryptedContent, nil
}
func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) {
return nil, nil
}
func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool {
return false
}
func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) {
}
func (f *fakeCryptoHelper) Init(context.Context) error {
return nil
}

View file

@ -1,4 +1,4 @@
// Copyright (c) 2026 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -8,20 +8,14 @@ package commands
import (
"fmt"
"slices"
"strings"
"sync"
"go.mau.fi/util/exmaps"
"maunium.net/go/mautrix/event/cmdschema"
)
type CommandContainer[MetaType any] struct {
commands map[string]*Handler[MetaType]
aliases map[string]string
lock sync.RWMutex
parent *Handler[MetaType]
}
func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
@ -31,29 +25,6 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
}
}
func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent {
data := make(exmaps.Set[*Handler[MetaType]])
cont.collectHandlers(data)
specs := make([]*cmdschema.EventContent, 0, data.Size())
for handler := range data.Iter() {
if handler.Parameters != nil {
specs = append(specs, handler.Spec())
}
}
return specs
}
func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) {
cont.lock.RLock()
defer cont.lock.RUnlock()
for _, handler := range cont.commands {
into.Add(handler)
if handler.subcommandContainer != nil {
handler.subcommandContainer.collectHandlers(into)
}
}
}
// Register registers the given command handlers.
func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) {
if cont == nil {
@ -61,10 +32,7 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType])
}
cont.lock.Lock()
defer cont.lock.Unlock()
for i, handler := range handlers {
if handler == nil {
panic(fmt.Errorf("handler #%d is nil", i+1))
}
for _, handler := range handlers {
cont.registerOne(handler)
}
}
@ -77,10 +45,6 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType])
} else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists {
panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget))
}
if !slices.Contains(handler.parents, cont.parent) {
handler.parents = append(handler.parents, cont.parent)
handler.nestedNameCache = nil
}
cont.commands[handler.Name] = handler
for _, alias := range handler.Aliases {
if strings.ToLower(alias) != alias {

View file

@ -1,4 +1,4 @@
// Copyright (c) 2026 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -8,7 +8,6 @@ package commands
import (
"context"
"encoding/json"
"fmt"
"strings"
@ -36,8 +35,6 @@ type Event[MetaType any] struct {
// RawArgs is the same as args, but without the splitting by whitespace.
RawArgs string
StructuredArgs json.RawMessage
Ctx context.Context
Log *zerolog.Logger
Proc *Processor[MetaType]
@ -64,7 +61,7 @@ var IDHTMLParser = &format.HTMLParser{
}
// ParseEvent parses a message into a command event struct.
func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] {
func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" {
return nil
@ -73,34 +70,12 @@ func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Even
if content.Format == event.FormatHTML {
text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx))
}
if content.MSC4391BotCommand != nil {
if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 {
return nil
}
wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand)
wrapped.RawInput = text
return wrapped
}
if len(text) == 0 {
return nil
}
return RawTextToEvent[MetaType](ctx, evt, text)
}
func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] {
commandParts := strings.Split(content.Command, " ")
return &Event[MetaType]{
Event: evt,
// Fake a command and args to let the subcommand finder in Process work.
Command: commandParts[0],
Args: commandParts[1:],
Ctx: ctx,
Log: zerolog.Ctx(ctx),
StructuredArgs: content.Arguments,
}
}
func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] {
parts := strings.Fields(text)
if len(parts) == 0 {
@ -213,25 +188,3 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) {
evt.RawArgs = arg + " " + evt.RawArgs
evt.Args = append([]string{arg}, evt.Args...)
}
func (evt *Event[MetaType]) ParseArgs(into any) error {
return json.Unmarshal(evt.StructuredArgs, into)
}
func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) {
err = evt.ParseArgs(&into)
return
}
func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) {
return func(evt *Event[MetaType]) {
parsed, err := ParseArgs[T, MetaType](evt)
if err != nil {
evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct")
// TODO better error, usage info? deduplicate with Process
evt.Reply("Failed to parse arguments: %v", err)
return
}
fn(evt, parsed)
}
}

View file

@ -1,4 +1,4 @@
// Copyright (c) 2026 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -8,9 +8,6 @@ package commands
import (
"strings"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/event/cmdschema"
)
type Handler[MetaType any] struct {
@ -28,63 +25,12 @@ type Handler[MetaType any] struct {
// Event.ShiftArg will likely be useful for implementing such parameters.
PreFunc func(ce *Event[MetaType])
// Description is a short description of the command.
Description *event.ExtensibleTextContainer
// Parameters is a description of structured command parameters.
// If set, the StructuredArgs field of Event will be populated.
Parameters []*cmdschema.Parameter
TailParam string
parents []*Handler[MetaType]
nestedNameCache []string
subcommandContainer *CommandContainer[MetaType]
}
func (h *Handler[MetaType]) NestedNames() []string {
if h.nestedNameCache != nil {
return h.nestedNameCache
}
nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents))
for _, parent := range h.parents {
if parent == nil {
nestedNames = append(nestedNames, h.Name)
nestedNames = append(nestedNames, h.Aliases...)
} else {
for _, parentName := range parent.NestedNames() {
nestedNames = append(nestedNames, parentName+" "+h.Name)
for _, alias := range h.Aliases {
nestedNames = append(nestedNames, parentName+" "+alias)
}
}
}
}
h.nestedNameCache = nestedNames
return nestedNames
}
func (h *Handler[MetaType]) Spec() *cmdschema.EventContent {
names := h.NestedNames()
return &cmdschema.EventContent{
Command: names[0],
Aliases: names[1:],
Parameters: h.Parameters,
Description: h.Description,
TailParam: h.TailParam,
}
}
func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) {
if h.Parameters == nil {
h.Parameters = other.Parameters
h.TailParam = other.TailParam
}
h.Func = other.Func
}
func (h *Handler[MetaType]) initSubcommandContainer() {
if len(h.Subcommands) > 0 {
h.subcommandContainer = NewCommandContainer[MetaType]()
h.subcommandContainer.parent = h
h.subcommandContainer.Register(h.Subcommands...)
} else {
h.subcommandContainer = nil

View file

@ -1,4 +1,4 @@
// Copyright (c) 2026 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
case event.EventReaction:
parsed = proc.ParseReaction(ctx, evt)
case event.EventMessage:
parsed = proc.ParseEvent(ctx, evt)
parsed = ParseEvent[MetaType](ctx, evt)
}
if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) {
if parsed == nil || !proc.PreValidator.Validate(parsed) {
return
}
parsed.Proc = proc
@ -107,12 +107,6 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
break
}
}
if parsed.StructuredArgs != nil && len(parsed.Args) > 0 {
// TODO allow unknown command handlers to be called?
// The client sent MSC4391 data, but the target command wasn't found
log.Debug().Msg("Didn't find handler for MSC4391 command")
return
}
logWith := log.With().
Str("command", parsed.Command).
@ -122,31 +116,11 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
}
if proc.LogArgs {
logWith = logWith.Strs("args", parsed.Args)
if parsed.StructuredArgs != nil {
logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs)
}
}
log = logWith.Logger()
parsed.Ctx = log.WithContext(ctx)
parsed.Log = &log
if handler.Parameters != nil && parsed.StructuredArgs == nil {
// The handler wants structured parameters, but the client didn't send MSC4391 data
var err error
parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs)
if err != nil {
log.Debug().Err(err).Msg("Failed to parse structured arguments")
// TODO better error, usage info? deduplicate with WithParsedArgs
parsed.Reply("Failed to parse arguments: %v", err)
return
}
if proc.LogArgs {
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.RawJSON("structured_args", parsed.StructuredArgs)
})
}
}
log.Debug().Msg("Processing command")
handler.Func(parsed)
}

View file

@ -1,4 +1,4 @@
// Copyright (c) 2026 Tulir Asokan
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -8,7 +8,6 @@ package commands
import (
"context"
"encoding/json"
"strings"
"github.com/rs/zerolog"
@ -20,11 +19,6 @@ import (
const ReactionCommandsKey = "fi.mau.reaction_commands"
const ReactionMultiUseKey = "fi.mau.reaction_multi_use"
type ReactionCommandData struct {
Command string `json:"command"`
Args any `json:"args,omitempty"`
}
func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] {
content, ok := evt.Content.Parsed.(*event.ReactionEventContent)
if !ok {
@ -73,33 +67,21 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E
Msg("Reaction command not found in target event")
return nil
}
var wrappedEvt *Event[MetaType]
switch typedCmd := rawCmd.(type) {
case string:
wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd)
case map[string]any:
var input event.MSC4391BotCommandInput
if marshaled, err := json.Marshal(typedCmd); err != nil {
} else if err = json.Unmarshal(marshaled, &input); err != nil {
} else {
wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input)
}
}
if wrappedEvt == nil {
cmdString, ok := rawCmd.(string)
if !ok {
zerolog.Ctx(ctx).Debug().
Stringer("target_event_id", evtID).
Str("reaction_key", content.RelatesTo.Key).
Msg("Reaction command data is invalid")
return nil
}
wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString)
wrappedEvt.Proc = proc
wrappedEvt.Redact()
if !isMultiUse {
DeleteAllReactions(ctx, proc.Client, evt)
}
if wrappedEvt.Command == "" {
if cmdString == "" {
return nil
}
return wrappedEvt

View file

@ -21,24 +21,13 @@ import (
)
var (
ErrHashMismatch = errors.New("mismatching SHA-256 digest")
ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version")
ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
ErrInvalidKey = errors.New("failed to decode key")
ErrInvalidInitVector = errors.New("failed to decode initialization vector")
ErrInvalidHash = errors.New("failed to decode SHA-256 hash")
ErrReaderClosed = errors.New("encrypting reader was already closed")
)
// Deprecated: use variables prefixed with Err
var (
HashMismatch = ErrHashMismatch
UnsupportedVersion = ErrUnsupportedVersion
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
InvalidKey = ErrInvalidKey
InvalidInitVector = ErrInvalidInitVector
InvalidHash = ErrInvalidHash
ReaderClosed = ErrReaderClosed
HashMismatch = errors.New("mismatching SHA-256 digest")
UnsupportedVersion = errors.New("unsupported Matrix file encryption version")
UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
InvalidKey = errors.New("failed to decode key")
InvalidInitVector = errors.New("failed to decode initialization vector")
InvalidHash = errors.New("failed to decode SHA-256 hash")
ReaderClosed = errors.New("encrypting reader was already closed")
)
var (
@ -96,25 +85,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error {
if ef.decoded != nil {
return nil
} else if len(ef.Key.Key) != keyBase64Length {
return ErrInvalidKey
return InvalidKey
} else if len(ef.InitVector) != ivBase64Length {
return ErrInvalidInitVector
return InvalidInitVector
} else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length {
return ErrInvalidHash
return InvalidHash
}
ef.decoded = &decodedKeys{}
_, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key))
if err != nil {
return ErrInvalidKey
return InvalidKey
}
_, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector))
if err != nil {
return ErrInvalidInitVector
return InvalidInitVector
}
if includeHash {
_, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256))
if err != nil {
return ErrInvalidHash
return InvalidHash
}
}
return nil
@ -190,7 +179,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil)
func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
if r.closed {
return 0, ErrReaderClosed
return 0, ReaderClosed
}
if offset != 0 || whence != io.SeekStart {
return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported")
@ -211,7 +200,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
func (r *encryptingReader) Read(dst []byte) (n int, err error) {
if r.closed {
return 0, ErrReaderClosed
return 0, ReaderClosed
} else if r.isDecrypting && r.file.decoded == nil {
if err = r.file.PrepareForDecryption(); err != nil {
return
@ -235,7 +224,7 @@ func (r *encryptingReader) Close() (err error) {
}
if r.isDecrypting {
if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) {
return ErrHashMismatch
return HashMismatch
}
} else {
r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil))
@ -276,9 +265,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) {
// DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function.
func (ef *EncryptedFile) PrepareForDecryption() error {
if ef.Version != "v2" {
return ErrUnsupportedVersion
return UnsupportedVersion
} else if ef.Key.Algorithm != "A256CTR" {
return ErrUnsupportedAlgorithm
return UnsupportedAlgorithm
} else if err := ef.decodeKeys(true); err != nil {
return err
}
@ -292,7 +281,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
}
dataHash := sha256.Sum256(data)
if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) {
return ErrHashMismatch
return HashMismatch
}
utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
return nil

View file

@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) {
file := parseHelloWorld()
file.Version = "foo"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, ErrUnsupportedVersion)
assert.ErrorIs(t, err, UnsupportedVersion)
}
func TestUnsupportedAlgorithm(t *testing.T) {
file := parseHelloWorld()
file.Key.Algorithm = "bar"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
assert.ErrorIs(t, err, UnsupportedAlgorithm)
}
func TestHashMismatch(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes))
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, ErrHashMismatch)
assert.ErrorIs(t, err, HashMismatch)
}
func TestTooLongHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, ErrInvalidHash)
assert.ErrorIs(t, err, InvalidHash)
}
func TestTooShortHash(t *testing.T) {
file := parseHelloWorld()
file.Hashes.SHA256 = "5/Gy1JftyyQ"
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
assert.ErrorIs(t, err, ErrInvalidHash)
assert.ErrorIs(t, err, InvalidHash)
}

View file

@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross
}
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
Master: masterKey,
SelfSigning: selfKey,
UserSigning: userKey,

View file

@ -20,20 +20,6 @@ type CrossSigningPublicKeysCache struct {
UserSigningKey id.Ed25519
}
func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) {
pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx)
if pubkeys != nil {
hasKeys = true
isVerified, err = mach.CryptoStore.IsKeySignedBy(
ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey,
)
if err != nil {
err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
}
}
return
}
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache {
if mach.crossSigningPubkeys != nil {
return mach.crossSigningPubkeys
@ -63,8 +49,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id
if len(dbKeys) > 0 {
masterKey, ok := dbKeys[id.XSUsageMaster]
if ok {
selfSigning := dbKeys[id.XSUsageSelfSigning]
userSigning := dbKeys[id.XSUsageUserSigning]
selfSigning, _ := dbKeys[id.XSUsageSelfSigning]
userSigning, _ := dbKeys[id.XSUsageUserSigning]
return &CrossSigningPublicKeysCache{
MasterKey: masterKey.Key,
SelfSigningKey: selfSigning.Key,

View file

@ -8,7 +8,6 @@ package crypto
import (
"context"
"errors"
"fmt"
"maunium.net/go/mautrix"
@ -72,46 +71,6 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex
}, passphrase)
}
func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error {
keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx)
if err != nil {
return fmt.Errorf("failed to get default SSSS key data: %w", err)
}
key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
if errors.Is(err, ssss.ErrUnverifiableKey) {
mach.machOrContextLog(ctx).Warn().
Str("key_id", keyID).
Msg("SSSS key is unverifiable, trying to use without verifying")
} else if err != nil {
return err
}
err = mach.FetchCrossSigningKeysFromSSSS(ctx, key)
if err != nil {
return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
}
err = mach.SignOwnDevice(ctx, mach.OwnIdentity())
if err != nil {
return fmt.Errorf("failed to sign own device: %w", err)
}
err = mach.SignOwnMasterKey(ctx)
if err != nil {
return fmt.Errorf("failed to sign own master key: %w", err)
}
return nil
}
func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) {
recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
if err != nil {
err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err)
} else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil {
err = fmt.Errorf("failed to sign own device: %w", err)
} else if err = mach.SignOwnMasterKey(ctx); err != nil {
err = fmt.Errorf("failed to sign own master key: %w", err)
}
return
}
// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys.
//
// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key
@ -138,12 +97,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u
// Publish cross-signing keys
err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback)
if err != nil {
return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err)
return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err)
}
err = mach.SSSS.SetDefaultKeyID(ctx, key.ID)
if err != nil {
return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
}
return key.RecoveryKey(), keysCache, nil

View file

@ -20,34 +20,36 @@ import (
func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
log := mach.machOrContextLog(ctx)
for userID, userKeys := range crossSigningKeys {
log := log.With().Stringer("user_id", userID).Logger()
log := log.With().Str("user_id", userID.String()).Logger()
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
}
for curKeyUsage, curKey := range currentKeys {
log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
// got a new key with the same usage as an existing key
for _, newKeyUsage := range userKeys.Usage {
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
log.Debug().
Int64("signature_count", count).
Msg("Dropped signatures made by old key as it has been replaced")
if currentKeys != nil {
for curKeyUsage, curKey := range currentKeys {
log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger()
// got a new key with the same usage as an existing key
for _, newKeyUsage := range userKeys.Usage {
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
log.Debug().
Int64("signature_count", count).
Msg("Dropped signatures made by old key as it has been replaced")
}
}
break
}
break
}
}
}
for _, key := range userKeys.Keys {
log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
for _, usage := range userKeys.Usage {
log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key")
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {

View file

@ -225,6 +225,13 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted)
}
if helper.client.SetAppServiceDeviceID {
err = helper.mach.ShareKeys(ctx, -1)
if err != nil {
return fmt.Errorf("failed to share keys: %w", err)
}
}
return nil
}
@ -261,24 +268,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error
if !ok || len(device.Keys) == 0 {
if isShared {
return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server")
} else {
helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine")
return nil
}
helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys")
err = helper.mach.ShareKeys(ctx, -1)
if err != nil {
return fmt.Errorf("failed to share keys: %w", err)
}
return nil
} else if !isShared {
return fmt.Errorf("olm account is not marked as shared, but there are keys on the server")
} else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed {
return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed)
}
if !isShared {
helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?")
} else {
helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine")
return nil
}
return nil
}
var NoSessionFound = crypto.ErrNoSessionFound
var NoSessionFound = crypto.NoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
@ -297,14 +304,24 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even
ctx = log.WithContext(ctx)
decrypted, err := helper.Decrypt(ctx, evt)
if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" {
go helper.waitForSession(ctx, evt)
} else if err != nil {
if errors.Is(err, NoSessionFound) {
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = helper.Decrypt(ctx, evt)
} else {
go helper.waitLongerForSession(ctx, log, evt)
return
}
}
if err != nil {
log.Warn().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
} else {
helper.postDecrypt(ctx, decrypted)
return
}
helper.postDecrypt(ctx, decrypted)
}
func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) {
@ -345,33 +362,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
}
}
func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) {
log := zerolog.Ctx(ctx)
content := evt.Content.AsEncrypted()
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err := helper.Decrypt(ctx, evt)
if err != nil {
log.Warn().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
} else {
helper.postDecrypt(ctx, decrypted)
}
} else {
go helper.waitLongerForSession(ctx, evt)
}
}
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) {
log := zerolog.Ctx(ctx)
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) {
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
@ -419,7 +413,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R
defer helper.lock.RUnlock()
encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content)
if err != nil {
if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) {
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
return
}
helper.log.Debug().

View file

@ -24,23 +24,13 @@ import (
)
var (
ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
ErrDuplicateMessageIndex = errors.New("duplicate megolm message index")
ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room")
ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
ErrRatchetError = errors.New("failed to ratchet session after use")
ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload")
)
// Deprecated: use variables prefixed with Err
var (
IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType
NoSessionFound = ErrNoSessionFound
DuplicateMessageIndex = ErrDuplicateMessageIndex
WrongRoom = ErrWrongRoom
DeviceKeyMismatch = ErrDeviceKeyMismatch
RatchetError = ErrRatchetError
IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
DuplicateMessageIndex = errors.New("duplicate megolm message index")
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
RatchetError = errors.New("failed to ratchet session after use")
)
type megolmEvent struct {
@ -55,30 +45,13 @@ var (
relatesToTopLevelPath = exgjson.Path("content", "m.relates_to")
)
const sessionIDLength = 43
func validateCiphertextCharacters(ciphertext []byte) bool {
for _, b := range ciphertext {
if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' {
return false
}
}
return true
}
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, ErrIncorrectEncryptedContentType
return nil, IncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmMegolmV1 {
return nil, ErrUnsupportedAlgorithm
} else if len(content.MegolmCiphertext) < 74 {
return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext))
} else if len(content.SessionID) != sessionIDLength {
return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID))
} else if !validateCiphertextCharacters(content.MegolmCiphertext) {
return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload)
return nil, UnsupportedAlgorithm
}
log := mach.machOrContextLog(ctx).With().
Str("action", "decrypt megolm event").
@ -124,13 +97,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
Msg("Couldn't resolve trust level of session: sent by unknown device")
trustLevel = id.TrustStateUnknownDevice
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
log.Debug().
Stringer("session_sender_key", sess.SenderKey).
Stringer("device_sender_key", device.IdentityKey).
Stringer("session_signing_key", sess.SigningKey).
Stringer("device_signing_key", device.SigningKey).
Msg("Device keys don't match keys in session, marking as untrusted")
trustLevel = id.TrustStateDeviceKeyMismatch
return nil, DeviceKeyMismatch
} else {
trustLevel, err = mach.ResolveTrustContext(ctx, device)
if err != nil {
@ -180,7 +147,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
if err != nil {
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
} else if megolmEvt.RoomID != encryptionRoomID {
return nil, ErrWrongRoom
return nil, WrongRoom
}
if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState {
megolmEvt.Type.Class = event.StateEventType
@ -213,7 +180,6 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
TrustSource: device,
ForwardedKeys: forwardedKeys,
WasEncrypted: true,
EventSource: evt.Mautrix.EventSource | event.SourceDecrypted,
ReceivedAt: evt.Mautrix.ReceivedAt,
},
}, nil
@ -235,19 +201,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
if decodeErr != nil {
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex)
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
}
firstKnown := sess.Internal.FirstKnownIndex()
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
} else if !ok {
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown)
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
}
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
}
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
@ -258,11 +224,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID)
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return sess, nil, 0, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
}
@ -270,7 +238,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex)
return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
}
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
@ -322,24 +290,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
return sess, plaintext, messageIndex, ErrRatchetError
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Deleted fully used session")
}
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, ErrRatchetError
return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, ErrRatchetError
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, ErrRatchetError
return sess, plaintext, messageIndex, RatchetError
} else {
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
}

View file

@ -17,36 +17,21 @@ import (
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix/crypto/goolm/account"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
var (
ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type")
ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
ErrSenderMismatch = errors.New("mismatched sender in olm payload")
ErrRecipientMismatch = errors.New("mismatched recipient in olm payload")
ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
ErrDuplicateMessage = errors.New("duplicate olm message")
)
// Deprecated: use variables prefixed with Err
var (
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
NotEncryptedForMe = ErrNotEncryptedForMe
UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType
DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession
DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage
SenderMismatch = ErrSenderMismatch
RecipientMismatch = ErrRecipientMismatch
RecipientKeyMismatch = ErrRecipientKeyMismatch
UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
UnsupportedOlmMessageType = errors.New("unsupported olm message type")
DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
SenderMismatch = errors.New("mismatched sender in olm payload")
RecipientMismatch = errors.New("mismatched recipient in olm payload")
RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
ErrDuplicateMessage = errors.New("duplicate olm message")
)
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
@ -68,13 +53,13 @@ type DecryptedOlmEvent struct {
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, ErrIncorrectEncryptedContentType
return nil, IncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmOlmV1 {
return nil, ErrUnsupportedAlgorithm
return nil, UnsupportedAlgorithm
}
ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()]
if !ok {
return nil, ErrNotEncryptedForMe
return nil, NotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
@ -90,7 +75,7 @@ type OlmEventKeys struct {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
return nil, ErrUnsupportedOlmMessageType
return nil, UnsupportedOlmMessageType
}
log := mach.machOrContextLog(ctx).With().
@ -114,11 +99,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
olmEvt.Type.Class = evt.Type.Class
if evt.Sender != olmEvt.Sender {
return nil, ErrSenderMismatch
return nil, SenderMismatch
} else if mach.Client.UserID != olmEvt.Recipient {
return nil, ErrRecipientMismatch
return nil, RecipientMismatch
} else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 {
return nil, ErrRecipientKeyMismatch
return nil, RecipientKeyMismatch
}
if len(olmEvt.Content.VeryRaw) > 0 {
@ -134,9 +119,6 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
}
func olmMessageHash(ciphertext string) ([32]byte, error) {
if ciphertext == "" {
return [32]byte{}, fmt.Errorf("empty ciphertext")
}
ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext)
return sha256.Sum256(ciphertextBytes), err
}
@ -166,7 +148,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
if err != nil {
if err == ErrDecryptionFailedWithMatchingSession {
if err == DecryptionFailedWithMatchingSession {
log.Warn().Msg("Found matching session, but decryption failed")
go mach.unwedgeDevice(log, sender, senderKey)
}
@ -184,10 +166,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(log, sender, senderKey)
return nil, ErrDecryptionFailedForNormalMessage
return nil, DecryptionFailedForNormalMessage
}
accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp"))
log.Trace().Msg("Trying to create inbound session")
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
@ -199,7 +180,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
log.Debug().
Hex("ciphertext_hash", ciphertextHash[:]).
Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
Str("olm_session_description", session.Describe()).
Msg("Created inbound olm session")
ctx = log.WithContext(ctx)
@ -208,19 +188,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
plaintext, err = session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
log.Debug().
Hex("ciphertext_hash", ciphertextHash[:]).
Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
Str("ciphertext", ciphertext).
Str("olm_session_description", session.Describe()).
Msg("DEBUG: Failed to decrypt prekey olm message with newly created session")
err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup)
if err2 != nil {
log.Debug().Err(err2).Msg("Goolm confirmed decryption failure")
} else {
log.Warn().Msg("Goolm decryption was successful after libolm failure?")
}
go mach.unwedgeDevice(log, sender, senderKey)
return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err)
}
@ -238,23 +205,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
return plaintext, nil
}
func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error {
acc, err := account.AccountFromPickled(accountBackup, []byte("tmp"))
if err != nil {
return fmt.Errorf("failed to unpickle olm account: %w", err)
}
sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext)
if err != nil {
return fmt.Errorf("failed to create inbound session: %w", err)
}
_, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey)
if err != nil {
// This is the expected result if libolm failed
return fmt.Errorf("failed to decrypt with new session: %w", err)
}
return nil
}
const MaxOlmSessionsPerDevice = 5
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
@ -313,11 +263,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
if err != nil {
log.Warn().Err(err).
Hex("ciphertext_hash", ciphertextHash[:]).
Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
Str("session_description", session.Describe()).
Msg("Failed to decrypt olm message")
if olmType == id.OlmMsgTypePreKey {
return nil, ErrDecryptionFailedWithMatchingSession
return nil, DecryptionFailedWithMatchingSession
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
@ -357,10 +306,10 @@ const MinUnwedgeInterval = 1 * time.Hour
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
log = log.With().Str("action", "unwedge olm session").Logger()
ctx := log.WithContext(mach.backgroundCtx)
ctx := log.WithContext(mach.BackgroundCtx)
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
delta := time.Since(prevUnwedge)
delta := time.Now().Sub(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
log.Debug().
Str("previous_recreation", delta.String()).
@ -391,10 +340,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send
return
}
log.Debug().
Time("last_created", lastCreatedAt).
Stringer("device_id", deviceIdentity.DeviceID).
Msg("Creating new Olm session")
log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session")
mach.devicesToUnwedgeLock.Lock()
mach.devicesToUnwedge[senderKey] = true
mach.devicesToUnwedgeLock.Unlock()

View file

@ -22,23 +22,14 @@ import (
)
var (
ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
ErrMismatchingSigningKey = errors.New("received update for device with different signing key")
ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key")
ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
ErrInvalidKeySignature = errors.New("invalid signature on device keys")
ErrUserNotTracked = errors.New("user is not tracked")
)
MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
MismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
MismatchingSigningKey = errors.New("received update for device with different signing key")
NoSigningKeyFound = errors.New("didn't find ed25519 signing key")
NoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
InvalidKeySignature = errors.New("invalid signature on device keys")
// Deprecated: use variables prefixed with Err
var (
MismatchingDeviceID = ErrMismatchingDeviceID
MismatchingUserID = ErrMismatchingUserID
MismatchingSigningKey = ErrMismatchingSigningKey
NoSigningKeyFound = ErrNoSigningKeyFound
NoIdentityKeyFound = ErrNoIdentityKeyFound
InvalidKeySignature = ErrInvalidKeySignature
ErrUserNotTracked = errors.New("user is not tracked")
)
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
@ -215,7 +206,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received")
data = make(map[id.UserID]map[id.DeviceID]*id.Device)
for userID, devices := range resp.DeviceKeys {
log := log.With().Stringer("user_id", userID).Logger()
log := log.With().Str("user_id", userID.String()).Logger()
delete(req.DeviceKeys, userID)
newDevices := make(map[id.DeviceID]*id.Device)
@ -231,7 +222,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
Msg("Updating devices in store")
changed := false
for deviceID, deviceKeys := range devices {
log := log.With().Stringer("device_id", deviceID).Logger()
log := log.With().Str("device_id", deviceID.String()).Logger()
existing, ok := existingDevices[deviceID]
if !ok {
// New device
@ -279,7 +270,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
}
}
for userID := range req.DeviceKeys {
log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user")
log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user")
}
mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys)
@ -321,28 +312,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID)
func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) {
if deviceID != deviceKeys.DeviceID {
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID)
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID)
} else if userID != deviceKeys.UserID {
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID)
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID)
}
signingKey := deviceKeys.Keys.GetEd25519(deviceID)
identityKey := deviceKeys.Keys.GetCurve25519(deviceID)
if signingKey == "" {
return nil, ErrNoSigningKeyFound
return nil, NoSigningKeyFound
} else if identityKey == "" {
return nil, ErrNoIdentityKeyFound
return nil, NoIdentityKeyFound
}
if existing != nil && existing.SigningKey != signingKey {
return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey)
return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
}
ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
if err != nil {
return existing, fmt.Errorf("failed to verify signature: %w", err)
} else if !ok {
return existing, ErrInvalidKeySignature
return existing, InvalidKeySignature
}
name, ok := deviceKeys.Unsigned["device_display_name"].(string)

View file

@ -25,12 +25,7 @@ import (
)
var (
ErrNoGroupSession = errors.New("no group session created")
)
// Deprecated: use variables prefixed with Err
var (
NoGroupSession = ErrNoGroupSession
NoGroupSession = errors.New("no group session created")
)
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
@ -46,7 +41,7 @@ func getRawJSON[T any](content json.RawMessage, path ...string) *T {
return &result
}
func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo {
func getRelatesTo(content any) *event.RelatesTo {
contentJSON, ok := content.(json.RawMessage)
if ok {
return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to")
@ -59,7 +54,7 @@ func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo {
if ok {
return relatable.OptionalGetRelatesTo()
}
return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to")
return nil
}
func getMentions(content any) *event.Mentions {
@ -87,20 +82,15 @@ type rawMegolmEvent struct {
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
func IsShareError(err error) bool {
return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
}
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
if len(ciphertext) == 0 {
return 0, fmt.Errorf("empty ciphertext")
}
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
var err error
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
if err != nil {
return 0, err
} else if len(decoded) < 2+binary.MaxVarintLen64 {
return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded))
} else if decoded[0] != 3 || decoded[1] != 8 {
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
}
@ -130,7 +120,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
} else if session == nil {
return nil, ErrNoGroupSession
return nil, NoGroupSession
}
plaintext, err := json.Marshal(&rawMegolmEvent{
RoomID: roomID,
@ -168,21 +158,12 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
Algorithm: id.AlgorithmMegolmV1,
SessionID: session.ID(),
MegolmCiphertext: ciphertext,
RelatesTo: getRelatesTo(content, plaintext),
RelatesTo: getRelatesTo(content),
// These are deprecated
SenderKey: mach.account.IdentityKey(),
DeviceID: mach.Client.DeviceID,
}
if mach.MSC4392Relations && encrypted.RelatesTo != nil {
// When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content.
// Other relations like threads are still left unencrypted.
encrypted.RelatesTo.InReplyTo = nil
encrypted.RelatesTo.IsFallingBack = false
if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" {
encrypted.RelatesTo = nil
}
}
if mach.PlaintextMentions {
encrypted.Mentions = getMentions(content)
}
@ -252,7 +233,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
var fetchKeysForUsers []id.UserID
for _, userID := range users {
log := log.With().Stringer("target_user_id", userID).Logger()
log := log.With().Str("target_user_id", userID.String()).Logger()
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
if err != nil {
log.Err(err).Msg("Failed to get devices of user")
@ -324,7 +305,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
toDeviceWithheld.Messages[userID] = withheld
}
log := log.With().Stringer("target_user_id", userID).Logger()
log := log.With().Str("target_user_id", userID.String()).Logger()
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
log.Debug().
@ -370,19 +351,26 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
log.Trace().Msg("Encrypting group session for all found devices")
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
logUsers := zerolog.Dict()
for userID, sessions := range olmSessions {
if len(sessions) == 0 {
continue
}
logDevices := zerolog.Dict()
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
log.Trace().
Stringer("target_user_id", userID).
Stringer("target_device_id", deviceID).
Stringer("target_identity_key", device.identity.IdentityKey).
Msg("Encrypting group session for device")
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
logDevices.Str(string(deviceID), string(device.identity.IdentityKey))
deviceCount++
log.Debug().
Stringer("target_user_id", userID).
Stringer("target_device_id", deviceID).
Stringer("target_identity_key", device.identity.IdentityKey).
Msg("Encrypted group session for device")
if !mach.DisableSharedGroupSessionTracking {
err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id)
if err != nil {
@ -396,13 +384,11 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
}
}
}
logUsers.Dict(string(userID), logDevices)
}
log.Debug().
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
Dict("destination_map", logUsers).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
return err

View file

@ -96,19 +96,15 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
panic(err)
}
log := mach.machOrContextLog(ctx)
log.Debug().
Str("recipient_identity_key", recipient.IdentityKey.String()).
Str("olm_session_id", session.ID().String()).
Str("olm_session_description", session.Describe()).
Msg("Encrypting olm message")
msgType, ciphertext, err := session.Encrypt(plaintext)
if err != nil {
panic(err)
}
ciphertextStr := string(ciphertext)
ciphertextHash, _ := olmMessageHash(ciphertextStr)
log.Debug().
Stringer("event_type", evtType).
Str("recipient_identity_key", recipient.IdentityKey.String()).
Str("olm_session_id", session.ID().String()).
Str("olm_session_description", session.Describe()).
Hex("ciphertext_hash", ciphertextHash[:]).
Msg("Encrypted olm message")
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
if err != nil {
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
@ -119,7 +115,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
OlmCiphertext: event.OlmCiphertexts{
recipient.IdentityKey: {
Type: msgType,
Body: ciphertextStr,
Body: string(ciphertext),
},
},
}

View file

@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
if err != nil {
return err
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
return err
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair

View file

@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) {
account, err := account.NewAccount()
assert.NoError(t, err)
err = account.Unpickle(pickled, pickleKey)
assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion)
assert.ErrorIs(t, err, olm.ErrBadVersion)
}
func TestLoopback(t *testing.T) {

View file

@ -10,7 +10,7 @@ import (
"maunium.net/go/mautrix/crypto/olm"
)
func Register() {
func init() {
olm.InitNewAccount = func() (olm.Account, error) {
return NewAccount()
}

View file

@ -53,7 +53,6 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
// SharedSecret returns the shared secret between the key pair and the given public key.
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
// Note: the standard library checks that the output is non-zero
return c.PrivateKey.SharedSecret(pubKey)
}

View file

@ -25,8 +25,6 @@ func TestCurve25519(t *testing.T) {
fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey)
assert.NoError(t, err)
assert.Equal(t, fromPrivate, firstKeypair)
_, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength))
assert.Error(t, err)
}
func TestCurve25519Case1(t *testing.T) {

View file

@ -4,8 +4,7 @@ import (
"encoding/base64"
)
// These methods should only be used for raw byte operations, never with string conversion
// Deprecated: base64.RawStdEncoding should be used directly
func Decode(input []byte) ([]byte, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
@ -15,6 +14,7 @@ func Decode(input []byte) ([]byte, error) {
return decoded[:writtenBytes], nil
}
// Deprecated: base64.RawStdEncoding should be used directly
func Encode(input []byte) []byte {
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
base64.RawStdEncoding.Encode(encoded, input)

View file

@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
}
}
if decrypted[0] != pickleVersion {
return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion)
return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion)
}
err = json.Unmarshal(decrypted[1:], object)
if err != nil {

View file

@ -3,9 +3,6 @@ package message
import (
"bytes"
"encoding/binary"
"fmt"
"maunium.net/go/mautrix/crypto/olm"
)
type Decoder struct {
@ -23,8 +20,6 @@ func (d *Decoder) ReadVarInt() (uint64, error) {
func (d *Decoder) ReadVarBytes() ([]byte, error) {
if n, err := d.ReadVarInt(); err != nil {
return nil, err
} else if n > uint64(d.Len()) {
return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available())
} else {
out := make([]byte, n)
_, err = d.Read(out)

View file

@ -2,12 +2,10 @@ package message
import (
"bytes"
"fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm/aessha2"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/olm"
)
const (
@ -38,9 +36,6 @@ func (r *GroupMessage) Decode(input []byte) (err error) {
if err != nil {
return
}
if r.Version != protocolVersion {
return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
}
for {
// Read Key

Some files were not shown because too many files have changed in this diff Show more