mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-15 06:45:51 +01:00
Compare commits
No commits in common. "main" and "v0.25.0" have entirely different histories.
231 changed files with 2291 additions and 12158 deletions
33
.github/workflows/go.yml
vendored
33
.github/workflows/go.yml
vendored
|
|
@ -10,12 +10,12 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
name: Lint (latest)
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.26"
|
||||
go-version: "1.24"
|
||||
cache: true
|
||||
|
||||
- name: Install libolm
|
||||
|
|
@ -24,7 +24,6 @@ jobs:
|
|||
- name: Install goimports
|
||||
run: |
|
||||
go install golang.org/x/tools/cmd/goimports@latest
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
export PATH="$HOME/go/bin:$PATH"
|
||||
|
||||
- name: Run pre-commit
|
||||
|
|
@ -35,14 +34,14 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ["1.25", "1.26"]
|
||||
name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm)
|
||||
go-version: ["1.24", "1.25"]
|
||||
name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm)
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v6
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
|
@ -61,28 +60,28 @@ jobs:
|
|||
- name: Test
|
||||
run: go test -json -v ./... 2>&1 | gotestfmt
|
||||
|
||||
- name: Test (jsonv2)
|
||||
env:
|
||||
GOEXPERIMENT: jsonv2
|
||||
run: go test -json -v ./... 2>&1 | gotestfmt
|
||||
|
||||
build-goolm:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ["1.25", "1.26"]
|
||||
name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm)
|
||||
go-version: ["1.24", "1.25"]
|
||||
name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm)
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v6
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
||||
- name: Set up gotestfmt
|
||||
uses: GoTestTools/gotestfmt-action@v2
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
rm -rf crypto/libolm
|
||||
|
|
|
|||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
|
|
@ -17,7 +17,7 @@ jobs:
|
|||
lock-stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: dessant/lock-threads@v6
|
||||
- uses: dessant/lock-threads@v5
|
||||
id: lock
|
||||
with:
|
||||
issue-inactive-days: 90
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v6.0.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude_types: [markdown]
|
||||
|
|
@ -9,7 +9,7 @@ repos:
|
|||
- id: check-added-large-files
|
||||
|
||||
- repo: https://github.com/tekwizely/pre-commit-golang
|
||||
rev: v1.0.0-rc.4
|
||||
rev: v1.0.0-rc.1
|
||||
hooks:
|
||||
- id: go-imports-repo
|
||||
args:
|
||||
|
|
@ -18,7 +18,8 @@ repos:
|
|||
- "-w"
|
||||
- id: go-vet-repo-mod
|
||||
- id: go-mod-tidy
|
||||
- id: go-staticcheck-repo-mod
|
||||
# TODO enable this
|
||||
#- id: go-staticcheck-repo-mod
|
||||
|
||||
- repo: https://github.com/beeper/pre-commit-go
|
||||
rev: v0.4.2
|
||||
|
|
@ -26,4 +27,3 @@ repos:
|
|||
- id: prevent-literal-http-methods
|
||||
- id: zerolog-ban-global-log
|
||||
- id: zerolog-ban-msgf
|
||||
- id: zerolog-use-stringer
|
||||
|
|
|
|||
201
CHANGELOG.md
201
CHANGELOG.md
|
|
@ -1,202 +1,4 @@
|
|||
## v0.26.3 (2026-02-16)
|
||||
|
||||
* Bumped minimum Go version to 1.25.
|
||||
* *(client)* Added fields for sending [MSC4354] sticky events.
|
||||
* *(bridgev2)* Added automatic message request accepting when sending message.
|
||||
* *(mediaproxy)* Added support for federation thumbnail endpoint.
|
||||
* *(crypto/ssss)* Improved support for recovery keys with slightly broken
|
||||
metadata.
|
||||
* *(crypto)* Changed key import to call session received callback even for
|
||||
sessions that already exist in the database.
|
||||
* *(appservice)* Fixed building websocket URL accidentally using file path
|
||||
separators instead of always `/`.
|
||||
* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field.
|
||||
* *(client)* Fixed incorrect context usage in async uploads.
|
||||
* *(crypto)* Fixed panic when passing invalid input to megolm message index
|
||||
parser used for debugging.
|
||||
* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned
|
||||
up properly.
|
||||
|
||||
[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354
|
||||
|
||||
## v0.26.2 (2026-01-16)
|
||||
|
||||
* *(bridgev2)* Added chunked portal deletion to avoid database locks when
|
||||
deleting large portals.
|
||||
* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata
|
||||
as per [MSC4392].
|
||||
* *(bridgev2/login)* Added `default_value` for user input fields.
|
||||
* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested
|
||||
HTTP client settings and to reset active connections of the network connector.
|
||||
* *(bridgev2)* Added interface to let network connectors get the provisioning
|
||||
API HTTP router and add new endpoints.
|
||||
* *(event)* Added blurhash field to Beeper link preview objects.
|
||||
* *(event)* Added [MSC4391] support for bot commands.
|
||||
* *(event)* Dropped [MSC4332] support for bot commands.
|
||||
* *(client)* Changed media download methods to return an error if the provided
|
||||
MXC URI is empty.
|
||||
* *(client)* Stabilized support for [MSC4323].
|
||||
* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events.
|
||||
* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with
|
||||
a portal re-ID call.
|
||||
|
||||
[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391
|
||||
[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392
|
||||
|
||||
## v0.26.1 (2025-12-16)
|
||||
|
||||
* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return
|
||||
the mime type from the callback rather than in the return get media return
|
||||
value. The callback can now also redirect the caller to a different file.
|
||||
* *(federation)* Added join/knock/leave functions
|
||||
(thanks to [@nexy7574] in [#422]).
|
||||
* *(federation/eventauth)* Fixed various incorrect checks.
|
||||
* *(client)* Added backoff for retrying media uploads to external URLs
|
||||
(with MSC3870).
|
||||
* *(bridgev2/config)* Added support for overriding config fields using
|
||||
environment variables.
|
||||
* *(bridgev2/commands)* Added command to mute chat on remote network.
|
||||
* *(bridgev2)* Added interface for network connectors to redirect to a different
|
||||
user ID when handling an invite from Matrix.
|
||||
* *(bridgev2)* Added interface for signaling message request status of portals.
|
||||
* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag
|
||||
is set in chat info.
|
||||
* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if
|
||||
bridging the new one is successful.
|
||||
* *(bridgev2/mxmain)* Improved error message when trying to run bridge with
|
||||
pre-megabridge database when no database migration exists.
|
||||
* *(bridgev2)* Improved reliability of database migration when enabling split
|
||||
portals.
|
||||
* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats.
|
||||
* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public
|
||||
portal rooms.
|
||||
* *(bridgev2)* Stopped hardcoding room versions in favor of checking
|
||||
server capabilities to determine appropriate `/createRoom` parameters.
|
||||
|
||||
[#422]: https://github.com/mautrix/go/pull/422
|
||||
|
||||
## v0.26.0 (2025-11-16)
|
||||
|
||||
* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent`
|
||||
has been able to do the same for a while now.
|
||||
* *(client,federation)* Added size limits for responses to make it safer to send
|
||||
requests to untrusted servers.
|
||||
* *(client)* Added wrapper for `/admin/whois` client API
|
||||
(thanks to [@nexy7574] in [#411]).
|
||||
* *(synapseadmin)* Added `force_purge` option to DeleteRoom
|
||||
(thanks to [@nexy7574] in [#420]).
|
||||
* *(statestore)* Added saving join rules for rooms.
|
||||
* *(bridgev2)* Added optional automatic rollback of room state if bridging the
|
||||
change to the remote network fails.
|
||||
* *(bridgev2)* Added management room notices if transient disconnect state
|
||||
doesn't resolve within 3 minutes.
|
||||
* *(bridgev2)* Added interface to signal that certain participants couldn't be
|
||||
invited when creating a group.
|
||||
* *(bridgev2)* Added `select` type for user input fields in login.
|
||||
* *(bridgev2)* Added interface to let network connector customize personal
|
||||
filtering space.
|
||||
* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to
|
||||
other bots.
|
||||
* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever
|
||||
possible.
|
||||
* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names,
|
||||
and encrypted files.
|
||||
* *(bridgev2/commands)* Added command to resync a single portal.
|
||||
* *(bridgev2/commands)* Added create group command.
|
||||
* *(bridgev2/config)* Added option to limit maximum number of logins.
|
||||
* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room
|
||||
is public.
|
||||
* *(bridgev2/disappear)* Changed read receipt handling to only start
|
||||
disappearing timers for messages up to the read message (note: may not work in
|
||||
all cases if the read receipt points at an unknown event).
|
||||
* *(event/reply)* Changed plaintext reply fallback removal to only happen when
|
||||
an HTML reply fallback is removed successfully.
|
||||
* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run.
|
||||
* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages.
|
||||
* *(federation)* Fixed HTTP method for sending transactions
|
||||
(thanks to [@nexy7574] in [#426]).
|
||||
* *(federation)* Fixed response body being closed even when using `DontReadBody`
|
||||
parameter.
|
||||
* *(federation)* Fixed validating auth for requests with query params.
|
||||
* *(federation/eventauth)* Fixed typo causing restricted joins to not work.
|
||||
|
||||
[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169
|
||||
[#411]: github.com/mautrix/go/pull/411
|
||||
[#420]: github.com/mautrix/go/pull/420
|
||||
[#426]: github.com/mautrix/go/pull/426
|
||||
|
||||
## v0.25.2 (2025-10-16)
|
||||
|
||||
* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into
|
||||
`ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old
|
||||
behavior, but most users likely want the relaxed version, as there are real
|
||||
users whose user IDs aren't valid under the strict rules.
|
||||
* *(crypto)* Added helper methods for generating and verifying with recovery
|
||||
keys.
|
||||
* *(bridgev2/matrix)* Added config option to automatically generate a recovery
|
||||
key for the bridge bot and self-sign the bridge's device.
|
||||
* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode
|
||||
for encryption with standard servers like Synapse.
|
||||
* *(bridgev2)* Added optional support for implicit read receipts.
|
||||
* *(bridgev2)* Added interface for deleting chats on remote network.
|
||||
* *(bridgev2)* Added local enforcement of media duration and size limits.
|
||||
* *(bridgev2)* Extended event duration logging to log any event taking too long.
|
||||
* *(bridgev2)* Improved validation in group creation provisioning API.
|
||||
* *(event)* Added event type constant for poll end events.
|
||||
* *(client)* Added wrapper for searching user directory.
|
||||
* *(client)* Improved support for managing [MSC4140] delayed events.
|
||||
* *(crypto/helper)* Changed default sync handling to not block on waiting for
|
||||
decryption keys. On initial sync, keys won't be requested at all by default.
|
||||
* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1).
|
||||
* *(bridgev2)* Fixed various bugs with migrating to split portals.
|
||||
* *(event)* Fixed poll start events having incorrect null `m.relates_to`.
|
||||
* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling.
|
||||
* *(federation)* Fixed various bugs in event auth.
|
||||
|
||||
## v0.25.1 (2025-09-16)
|
||||
|
||||
* *(client)* Fixed HTTP method of delete devices API call
|
||||
(thanks to [@fmseals] in [#393]).
|
||||
* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints
|
||||
(thanks to [@nexy7574] in [#407]).
|
||||
* *(client)* Stabilized support for extensible profiles.
|
||||
* *(client)* Stabilized support for `state_after` in sync.
|
||||
* *(client)* Removed deprecated MSC2716 requests.
|
||||
* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if
|
||||
the content struct doesn't implement `Relatable`.
|
||||
* *(crypto)* Changed olm unwedging to ignore newly created sessions if they
|
||||
haven't been used successfully in either direction.
|
||||
* *(federation)* Added utilities for generating, parsing, validating and
|
||||
authorizing PDUs.
|
||||
* Note: the new PDU code depends on `GOEXPERIMENT=jsonv2`
|
||||
* *(event)* Added `is_animated` flag from [MSC4230] to file info.
|
||||
* *(event)* Added types for [MSC4332]: In-room bot commands.
|
||||
* *(event)* Added missing poll end event type for [MSC3381].
|
||||
* *(appservice)* Fixed URLs not being escaped properly when using unix socket
|
||||
for homeserver connections.
|
||||
* *(format)* Added more helpers for forming markdown links.
|
||||
* *(event,bridgev2)* Added support for Beeper's disappearing message state event.
|
||||
* *(bridgev2)* Redesigned group creation interface and added support in commands
|
||||
and provisioning API.
|
||||
* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to
|
||||
get an old event. The method is best effort only, as some configurations don't
|
||||
allow fetching old events.
|
||||
* *(bridgev2)* Added shared logic for provisioning that can be reused by the
|
||||
API, commands and other sources.
|
||||
* *(bridgev2)* Fixed mentions and URL previews not being copied over when
|
||||
caption and media are merged.
|
||||
* *(bridgev2)* Removed config option to change provisioning API prefix, which
|
||||
had already broken in the previous release.
|
||||
|
||||
[@fmseals]: https://github.com/fmseals
|
||||
[#393]: https://github.com/mautrix/go/pull/393
|
||||
[#407]: https://github.com/mautrix/go/pull/407
|
||||
[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381
|
||||
[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230
|
||||
[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323
|
||||
[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332
|
||||
|
||||
## v0.25.0 (2025-08-16)
|
||||
## v0.25.0 (unreleased)
|
||||
|
||||
* Bumped minimum Go version to 1.24.
|
||||
* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux
|
||||
|
|
@ -437,7 +239,6 @@
|
|||
[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156
|
||||
[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190
|
||||
[#288]: https://github.com/mautrix/go/pull/288
|
||||
[@onestacked]: https://github.com/onestacked
|
||||
|
||||
## v0.22.0 (2024-11-16)
|
||||
|
||||
|
|
|
|||
10
README.md
10
README.md
|
|
@ -1,9 +1,8 @@
|
|||
# mautrix-go
|
||||
[](https://pkg.go.dev/maunium.net/go/mautrix)
|
||||
|
||||
A Golang Matrix framework. Used by [gomuks](https://gomuks.app),
|
||||
[go-neb](https://github.com/matrix-org/go-neb),
|
||||
[mautrix-whatsapp](https://github.com/mautrix/whatsapp)
|
||||
A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks),
|
||||
[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp)
|
||||
and others.
|
||||
|
||||
Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net)
|
||||
|
|
@ -14,10 +13,9 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or
|
|||
In addition to the basic client API features the original project has, this framework also has:
|
||||
|
||||
* Appservice support (Intent API like mautrix-python, room state storage, etc)
|
||||
* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc)
|
||||
* End-to-end encryption support (incl. interactive SAS verification)
|
||||
* High-level module for building puppeting bridges
|
||||
* Partial federation module (making requests, PDU processing and event authorization)
|
||||
* A media proxy server which can be used to expose anything as a Matrix media repo
|
||||
* High-level module for building chat clients
|
||||
* Wrapper functions for the Synapse admin API
|
||||
* Structs for parsing event content
|
||||
* Helpers for parsing and generating Matrix HTML
|
||||
|
|
|
|||
|
|
@ -334,7 +334,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error {
|
|||
} else if as.hsURLForClient.Scheme == "" {
|
||||
as.hsURLForClient.Scheme = "https"
|
||||
}
|
||||
as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath()
|
||||
as.hsURLForClient.RawPath = parsedURL.EscapedPath()
|
||||
|
||||
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
||||
as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar}
|
||||
|
|
@ -360,7 +360,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
|
|||
AccessToken: as.Registration.AppToken,
|
||||
UserAgent: as.UserAgent,
|
||||
StateStore: as.StateStore,
|
||||
Log: as.Log.With().Stringer("as_user_id", userID).Logger(),
|
||||
Log: as.Log.With().Str("as_user_id", userID.String()).Logger(),
|
||||
Client: as.HTTPClient,
|
||||
DefaultHTTPRetries: as.DefaultHTTPRetries,
|
||||
SpecVersions: as.SpecVersions,
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def
|
|||
}
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if errors.Is(err, event.ErrUnsupportedContentType) {
|
||||
log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event")
|
||||
log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event")
|
||||
} else if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Str("event_id", evt.ID.String()).
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
|
|||
}
|
||||
|
||||
func (intent *IntentAPI) Register(ctx context.Context) error {
|
||||
_, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{
|
||||
_, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{
|
||||
Username: intent.Localpart,
|
||||
Type: mautrix.AuthTypeAppservice,
|
||||
InhibitLogin: true,
|
||||
|
|
@ -214,31 +214,23 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any {
|
|||
}
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
|
||||
func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(ctx, roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValue(contentJSON)
|
||||
return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...)
|
||||
return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(ctx, roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
|
||||
return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValue(contentJSON)
|
||||
return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...)
|
||||
}
|
||||
|
||||
// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead
|
||||
func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
|
||||
return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
|
||||
if err := intent.EnsureJoined(ctx, roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
|
||||
return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) {
|
||||
func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
|
||||
if eventType != event.StateMember || stateKey != string(intent.UserID) {
|
||||
if err := intent.EnsureJoined(ctx, roomID); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -247,12 +239,15 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e
|
|||
return nil, err
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValue(contentJSON)
|
||||
return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...)
|
||||
return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON)
|
||||
}
|
||||
|
||||
// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead
|
||||
func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
|
||||
return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
|
||||
if err := intent.EnsureJoined(ctx, roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts)
|
||||
return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
|
||||
|
|
@ -311,7 +306,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
|
|||
func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) {
|
||||
if intent.IsCustomPuppet || len(extraContent) > 0 {
|
||||
_, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...)
|
||||
return &mautrix.RespJoinRoom{RoomID: roomID}, err
|
||||
return &mautrix.RespJoinRoom{}, err
|
||||
}
|
||||
return intent.Client.JoinRoomByID(ctx, roomID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,10 +11,9 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
|
@ -56,7 +55,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest {
|
|||
var prefixMessage string
|
||||
for unwrappedErr != nil {
|
||||
errorData, jsonErr = json.Marshal(unwrappedErr)
|
||||
if len(errorData) > 2 && jsonErr == nil {
|
||||
if errorData != nil && len(errorData) > 2 && jsonErr == nil {
|
||||
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
|
||||
prefixMessage = strings.TrimRight(prefixMessage, ": ")
|
||||
break
|
||||
|
|
@ -293,16 +292,10 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error)
|
|||
as.Log.Debug().Msg("Ignoring non-text message from websocket")
|
||||
continue
|
||||
}
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
as.Log.Debug().Err(err).Msg("Error reading data from websocket")
|
||||
stopFunc(parseCloseError(err))
|
||||
return
|
||||
}
|
||||
var msg WebsocketMessage
|
||||
err = json.Unmarshal(data, &msg)
|
||||
err = json.NewDecoder(reader).Decode(&msg)
|
||||
if err != nil {
|
||||
as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket")
|
||||
as.Log.Debug().Err(err).Msg("Error reading JSON from websocket")
|
||||
stopFunc(parseCloseError(err))
|
||||
return
|
||||
}
|
||||
|
|
@ -374,7 +367,7 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
|
|||
copiedURL := *as.hsURLForClient
|
||||
parsed = &copiedURL
|
||||
}
|
||||
parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
|
||||
parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
|
||||
if parsed.Scheme == "http" {
|
||||
parsed.Scheme = "ws"
|
||||
} else if parsed.Scheme == "https" {
|
||||
|
|
@ -419,7 +412,6 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn
|
|||
}
|
||||
})
|
||||
}
|
||||
ws.SetReadLimit(50 * 1024 * 1024)
|
||||
as.ws = ws
|
||||
as.StopWebsocket = stopFunc
|
||||
as.PrepareWebsocket()
|
||||
|
|
|
|||
|
|
@ -9,14 +9,11 @@ package bridgev2
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exhttp"
|
||||
"go.mau.fi/util/exsync"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
|
||||
|
|
@ -54,7 +51,6 @@ type Bridge struct {
|
|||
|
||||
Background bool
|
||||
ExternallyManagedDB bool
|
||||
stopping atomic.Bool
|
||||
|
||||
wakeupBackfillQueue chan struct{}
|
||||
stopBackfillQueue *exsync.Event
|
||||
|
|
@ -130,7 +126,6 @@ func (br *Bridge) Start(ctx context.Context) error {
|
|||
|
||||
func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error {
|
||||
br.Background = true
|
||||
br.stopping.Store(false)
|
||||
err := br.StartConnectors(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -166,7 +161,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
|
|||
case <-time.After(20 * time.Second):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
br.stopping.Store(true)
|
||||
return nil
|
||||
} else {
|
||||
br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode")
|
||||
|
|
@ -176,7 +170,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa
|
|||
|
||||
func (br *Bridge) StartConnectors(ctx context.Context) error {
|
||||
br.Log.Info().Msg("Starting bridge")
|
||||
br.stopping.Store(false)
|
||||
if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil {
|
||||
br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background())
|
||||
br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx)
|
||||
|
|
@ -189,11 +182,7 @@ func (br *Bridge) StartConnectors(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
if !br.Background {
|
||||
var postMigrate func()
|
||||
br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx)
|
||||
if postMigrate != nil {
|
||||
defer postMigrate()
|
||||
}
|
||||
br.didSplitPortals = br.MigrateToSplitPortals(ctx)
|
||||
}
|
||||
br.Log.Info().Msg("Starting Matrix connector")
|
||||
err := br.Matrix.Start(ctx)
|
||||
|
|
@ -282,64 +271,20 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps b
|
|||
Msg("Resent bridge info to all portals")
|
||||
}
|
||||
|
||||
func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) {
|
||||
func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool {
|
||||
log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" {
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
affected, err := br.DB.Portal.MigrateToSplitPortals(ctx)
|
||||
if err != nil {
|
||||
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals")
|
||||
os.Exit(31)
|
||||
return false, nil
|
||||
log.Err(err).Msg("Failed to migrate portals")
|
||||
return false
|
||||
}
|
||||
log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals")
|
||||
affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to fix parent portals after split portal migration")
|
||||
os.Exit(31)
|
||||
return false, nil
|
||||
}
|
||||
log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration")
|
||||
withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx)
|
||||
if err != nil {
|
||||
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate")
|
||||
os.Exit(31)
|
||||
return false, nil
|
||||
}
|
||||
var roomsToDelete []id.RoomID
|
||||
log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver")
|
||||
for _, portal := range withoutReceiver {
|
||||
if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil {
|
||||
log.Err(err).
|
||||
Str("portal_id", string(portal.ID)).
|
||||
Stringer("mxid", portal.MXID).
|
||||
Msg("Failed to delete portal database row that failed to migrate")
|
||||
} else if portal.MXID != "" {
|
||||
log.Debug().
|
||||
Str("portal_id", string(portal.ID)).
|
||||
Stringer("mxid", portal.MXID).
|
||||
Msg("Marked portal room for deletion from homeserver")
|
||||
roomsToDelete = append(roomsToDelete, portal.MXID)
|
||||
} else {
|
||||
log.Debug().
|
||||
Str("portal_id", string(portal.ID)).
|
||||
Msg("Deleted portal row with no Matrix room")
|
||||
}
|
||||
}
|
||||
br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true")
|
||||
log.Info().Msg("Finished split portal migration successfully")
|
||||
return affected > 0, func() {
|
||||
for _, roomID := range roomsToDelete {
|
||||
if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil {
|
||||
log.Err(err).
|
||||
Stringer("mxid", roomID).
|
||||
Msg("Failed to delete portal room that failed to migrate")
|
||||
}
|
||||
}
|
||||
log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate")
|
||||
}
|
||||
return affected > 0
|
||||
}
|
||||
|
||||
func (br *Bridge) StartLogins(ctx context.Context) error {
|
||||
|
|
@ -374,46 +319,6 @@ func (br *Bridge) StartLogins(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (br *Bridge) ResetNetworkConnections() {
|
||||
nrn, ok := br.Network.(NetworkResettingNetwork)
|
||||
if ok {
|
||||
br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections")
|
||||
nrn.ResetNetworkConnections()
|
||||
return
|
||||
}
|
||||
|
||||
br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually")
|
||||
for _, login := range br.GetAllCachedUserLogins() {
|
||||
login.Log.Debug().Msg("Disconnecting and recreating client for network reset")
|
||||
ctx := login.Log.WithContext(br.BackgroundCtx)
|
||||
login.Client.Disconnect()
|
||||
err := login.recreateClient(ctx)
|
||||
if err != nil {
|
||||
login.Log.Err(err).Msg("Failed to recreate client during network reset")
|
||||
login.BridgeState.Send(status.BridgeState{
|
||||
StateEvent: status.StateUnknownError,
|
||||
Error: "bridgev2-network-reset-fail",
|
||||
Info: map[string]any{"go_error": err.Error()},
|
||||
})
|
||||
} else {
|
||||
login.Client.Connect(ctx)
|
||||
}
|
||||
}
|
||||
br.Log.Info().Msg("Finished resetting all user logins")
|
||||
}
|
||||
|
||||
func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings {
|
||||
mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings)
|
||||
if ok {
|
||||
return mchs.GetHTTPClientSettings()
|
||||
}
|
||||
return exhttp.SensibleClientSettings
|
||||
}
|
||||
|
||||
func (br *Bridge) IsStopping() bool {
|
||||
return br.stopping.Load()
|
||||
}
|
||||
|
||||
func (br *Bridge) Stop() {
|
||||
br.stop(false, 0)
|
||||
}
|
||||
|
|
@ -424,7 +329,6 @@ func (br *Bridge) StopWithTimeout(timeout time.Duration) {
|
|||
|
||||
func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) {
|
||||
br.Log.Info().Msg("Shutting down bridge")
|
||||
br.stopping.Store(true)
|
||||
br.DisappearLoop.Stop()
|
||||
br.stopBackfillQueue.Set()
|
||||
br.Matrix.PreStop()
|
||||
|
|
|
|||
|
|
@ -34,12 +34,10 @@ type BackfillQueueConfig struct {
|
|||
MaxBatchesOverride map[string]int `yaml:"max_batches_override"`
|
||||
}
|
||||
|
||||
func (bqc *BackfillQueueConfig) GetOverride(names ...string) int {
|
||||
for _, name := range names {
|
||||
override, ok := bqc.MaxBatchesOverride[name]
|
||||
if ok {
|
||||
return override
|
||||
}
|
||||
func (bqc *BackfillQueueConfig) GetOverride(name string) int {
|
||||
override, ok := bqc.MaxBatchesOverride[name]
|
||||
if !ok {
|
||||
return bqc.MaxBatches
|
||||
}
|
||||
return bqc.MaxBatches
|
||||
return override
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,8 +33,6 @@ type Config struct {
|
|||
Encryption EncryptionConfig `yaml:"encryption"`
|
||||
Logging zeroconfig.Config `yaml:"logging"`
|
||||
|
||||
EnvConfigPrefix string `yaml:"env_config_prefix"`
|
||||
|
||||
ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"`
|
||||
}
|
||||
|
||||
|
|
@ -62,40 +60,36 @@ type CleanupOnLogouts struct {
|
|||
}
|
||||
|
||||
type BridgeConfig struct {
|
||||
CommandPrefix string `yaml:"command_prefix"`
|
||||
PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
|
||||
PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
|
||||
AsyncEvents bool `yaml:"async_events"`
|
||||
SplitPortals bool `yaml:"split_portals"`
|
||||
ResendBridgeInfo bool `yaml:"resend_bridge_info"`
|
||||
NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
|
||||
BridgeStatusNotices string `yaml:"bridge_status_notices"`
|
||||
UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
|
||||
UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"`
|
||||
BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
|
||||
BridgeNotices bool `yaml:"bridge_notices"`
|
||||
TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
|
||||
OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
|
||||
MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
|
||||
DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
|
||||
CrossRoomReplies bool `yaml:"cross_room_replies"`
|
||||
OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
|
||||
RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"`
|
||||
KickMatrixUsers bool `yaml:"kick_matrix_users"`
|
||||
CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
|
||||
Relay RelayConfig `yaml:"relay"`
|
||||
Permissions PermissionConfig `yaml:"permissions"`
|
||||
Backfill BackfillConfig `yaml:"backfill"`
|
||||
CommandPrefix string `yaml:"command_prefix"`
|
||||
PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"`
|
||||
PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
|
||||
AsyncEvents bool `yaml:"async_events"`
|
||||
SplitPortals bool `yaml:"split_portals"`
|
||||
ResendBridgeInfo bool `yaml:"resend_bridge_info"`
|
||||
NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"`
|
||||
BridgeStatusNotices string `yaml:"bridge_status_notices"`
|
||||
UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"`
|
||||
BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"`
|
||||
BridgeNotices bool `yaml:"bridge_notices"`
|
||||
TagOnlyOnCreate bool `yaml:"tag_only_on_create"`
|
||||
OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"`
|
||||
MuteOnlyOnCreate bool `yaml:"mute_only_on_create"`
|
||||
DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"`
|
||||
CrossRoomReplies bool `yaml:"cross_room_replies"`
|
||||
OutgoingMessageReID bool `yaml:"outgoing_message_re_id"`
|
||||
CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"`
|
||||
Relay RelayConfig `yaml:"relay"`
|
||||
Permissions PermissionConfig `yaml:"permissions"`
|
||||
Backfill BackfillConfig `yaml:"backfill"`
|
||||
}
|
||||
|
||||
type MatrixConfig struct {
|
||||
MessageStatusEvents bool `yaml:"message_status_events"`
|
||||
DeliveryReceipts bool `yaml:"delivery_receipts"`
|
||||
MessageErrorNotices bool `yaml:"message_error_notices"`
|
||||
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
|
||||
FederateRooms bool `yaml:"federate_rooms"`
|
||||
UploadFileThreshold int64 `yaml:"upload_file_threshold"`
|
||||
GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"`
|
||||
MessageStatusEvents bool `yaml:"message_status_events"`
|
||||
DeliveryReceipts bool `yaml:"delivery_receipts"`
|
||||
MessageErrorNotices bool `yaml:"message_error_notices"`
|
||||
SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
|
||||
FederateRooms bool `yaml:"federate_rooms"`
|
||||
UploadFileThreshold int64 `yaml:"upload_file_threshold"`
|
||||
}
|
||||
|
||||
type AnalyticsConfig struct {
|
||||
|
|
@ -105,6 +99,7 @@ type AnalyticsConfig struct {
|
|||
}
|
||||
|
||||
type ProvisioningConfig struct {
|
||||
Prefix string `yaml:"prefix"`
|
||||
SharedSecret string `yaml:"shared_secret"`
|
||||
DebugEndpoints bool `yaml:"debug_endpoints"`
|
||||
EnableSessionTransfers bool `yaml:"enable_session_transfers"`
|
||||
|
|
@ -117,12 +112,10 @@ type DirectMediaConfig struct {
|
|||
}
|
||||
|
||||
type PublicMediaConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
SigningKey string `yaml:"signing_key"`
|
||||
Expiry int `yaml:"expiry"`
|
||||
HashLength int `yaml:"hash_length"`
|
||||
PathPrefix string `yaml:"path_prefix"`
|
||||
UseDatabase bool `yaml:"use_database"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
SigningKey string `yaml:"signing_key"`
|
||||
HashLength int `yaml:"hash_length"`
|
||||
Expiry int `yaml:"expiry"`
|
||||
}
|
||||
|
||||
type DoublePuppetConfig struct {
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ type EncryptionConfig struct {
|
|||
Require bool `yaml:"require"`
|
||||
Appservice bool `yaml:"appservice"`
|
||||
MSC4190 bool `yaml:"msc4190"`
|
||||
MSC4392 bool `yaml:"msc4392"`
|
||||
SelfSign bool `yaml:"self_sign"`
|
||||
|
||||
PlaintextMentions bool `yaml:"plaintext_mentions"`
|
||||
|
||||
|
|
|
|||
|
|
@ -133,7 +133,9 @@ func doMigrateLegacy(helper up.Helper, python bool) {
|
|||
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"})
|
||||
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"})
|
||||
|
||||
CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
|
||||
CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
|
||||
CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"})
|
||||
CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"})
|
||||
CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"})
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ type Permissions struct {
|
|||
DoublePuppet bool `yaml:"double_puppet"`
|
||||
Admin bool `yaml:"admin"`
|
||||
ManageRelay bool `yaml:"manage_relay"`
|
||||
MaxLogins int `yaml:"max_logins"`
|
||||
}
|
||||
|
||||
type PermissionConfig map[string]*Permissions
|
||||
|
|
@ -41,7 +40,10 @@ func (pc PermissionConfig) IsConfigured() bool {
|
|||
_, hasExampleDomain := pc["example.com"]
|
||||
_, hasExampleUser := pc["@admin:example.com"]
|
||||
exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain)
|
||||
return len(pc) > exampleLen
|
||||
if len(pc) <= exampleLen {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (pc PermissionConfig) Get(userID id.UserID) Permissions {
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ func doUpgrade(helper up.Helper) {
|
|||
helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key")
|
||||
helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices")
|
||||
helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect")
|
||||
helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects")
|
||||
helper.Copy(up.Bool, "bridge", "bridge_matrix_leave")
|
||||
helper.Copy(up.Bool, "bridge", "bridge_notices")
|
||||
helper.Copy(up.Bool, "bridge", "tag_only_on_create")
|
||||
|
|
@ -41,8 +40,6 @@ func doUpgrade(helper up.Helper) {
|
|||
helper.Copy(up.Bool, "bridge", "mute_only_on_create")
|
||||
helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages")
|
||||
helper.Copy(up.Bool, "bridge", "cross_room_replies")
|
||||
helper.Copy(up.Bool, "bridge", "revert_failed_state_changes")
|
||||
helper.Copy(up.Bool, "bridge", "kick_matrix_users")
|
||||
helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled")
|
||||
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private")
|
||||
helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed")
|
||||
|
|
@ -101,12 +98,12 @@ func doUpgrade(helper up.Helper) {
|
|||
helper.Copy(up.Bool, "matrix", "sync_direct_chat_list")
|
||||
helper.Copy(up.Bool, "matrix", "federate_rooms")
|
||||
helper.Copy(up.Int, "matrix", "upload_file_threshold")
|
||||
helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info")
|
||||
|
||||
helper.Copy(up.Str|up.Null, "analytics", "token")
|
||||
helper.Copy(up.Str|up.Null, "analytics", "url")
|
||||
helper.Copy(up.Str|up.Null, "analytics", "user_id")
|
||||
|
||||
helper.Copy(up.Str, "provisioning", "prefix")
|
||||
if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" {
|
||||
sharedSecret := random.String(64)
|
||||
helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret")
|
||||
|
|
@ -136,8 +133,6 @@ func doUpgrade(helper up.Helper) {
|
|||
}
|
||||
helper.Copy(up.Int, "public_media", "expiry")
|
||||
helper.Copy(up.Int, "public_media", "hash_length")
|
||||
helper.Copy(up.Str|up.Null, "public_media", "path_prefix")
|
||||
helper.Copy(up.Bool, "public_media", "use_database")
|
||||
|
||||
helper.Copy(up.Bool, "backfill", "enabled")
|
||||
helper.Copy(up.Int, "backfill", "max_initial_messages")
|
||||
|
|
@ -163,8 +158,6 @@ func doUpgrade(helper up.Helper) {
|
|||
} else {
|
||||
helper.Copy(up.Bool, "encryption", "msc4190")
|
||||
}
|
||||
helper.Copy(up.Bool, "encryption", "msc4392")
|
||||
helper.Copy(up.Bool, "encryption", "self_sign")
|
||||
helper.Copy(up.Bool, "encryption", "allow_key_sharing")
|
||||
if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" {
|
||||
helper.Set(up.Str, random.String(64), "encryption", "pickle_key")
|
||||
|
|
@ -187,8 +180,6 @@ func doUpgrade(helper up.Helper) {
|
|||
helper.Copy(up.Int, "encryption", "rotation", "messages")
|
||||
helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation")
|
||||
|
||||
helper.Copy(up.Str|up.Null, "env_config_prefix")
|
||||
|
||||
helper.Copy(up.Map, "logging")
|
||||
}
|
||||
|
||||
|
|
@ -216,7 +207,6 @@ var SpacedBlocks = [][]string{
|
|||
{"backfill"},
|
||||
{"double_puppet"},
|
||||
{"encryption"},
|
||||
{"env_config_prefix"},
|
||||
{"logging"},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,15 +15,12 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exfmt"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/status"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
)
|
||||
|
||||
var CatchBridgeStateQueuePanics = true
|
||||
|
||||
type BridgeStateQueue struct {
|
||||
prevUnsent *status.BridgeState
|
||||
prevSent *status.BridgeState
|
||||
|
|
@ -32,13 +29,8 @@ type BridgeStateQueue struct {
|
|||
bridge *Bridge
|
||||
login *UserLogin
|
||||
|
||||
firstTransientDisconnect time.Time
|
||||
cancelScheduledNotice atomic.Pointer[context.CancelFunc]
|
||||
|
||||
stopChan chan struct{}
|
||||
stopReconnect atomic.Pointer[context.CancelFunc]
|
||||
|
||||
unknownErrorReconnects int
|
||||
}
|
||||
|
||||
func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) {
|
||||
|
|
@ -82,63 +74,31 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() {
|
|||
if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil {
|
||||
(*cancelFn)()
|
||||
}
|
||||
if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil {
|
||||
(*cancelFn)()
|
||||
}
|
||||
}
|
||||
|
||||
func (bsq *BridgeStateQueue) loop() {
|
||||
if CatchBridgeStateQueuePanics {
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
bsq.login.Log.Error().
|
||||
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
|
||||
Any(zerolog.ErrorFieldName, err).
|
||||
Msg("Panic in bridge state loop")
|
||||
}
|
||||
}()
|
||||
}
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil {
|
||||
bsq.login.Log.Error().
|
||||
Bytes(zerolog.ErrorStackFieldName, debug.Stack()).
|
||||
Any(zerolog.ErrorFieldName, err).
|
||||
Msg("Panic in bridge state loop")
|
||||
}
|
||||
}()
|
||||
for state := range bsq.ch {
|
||||
bsq.immediateSendBridgeState(state)
|
||||
}
|
||||
}
|
||||
|
||||
func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) {
|
||||
log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger()
|
||||
ctx := log.WithContext(bsq.bridge.BackgroundCtx)
|
||||
if !bsq.waitForTransientDisconnectReconnect(ctx) {
|
||||
return
|
||||
}
|
||||
prevUnsent := bsq.GetPrevUnsent()
|
||||
prev := bsq.GetPrev()
|
||||
if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent ||
|
||||
prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect {
|
||||
log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice")
|
||||
return
|
||||
}
|
||||
log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice")
|
||||
bsq.sendNotice(ctx, triggeredBy, true)
|
||||
}
|
||||
|
||||
func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) {
|
||||
func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) {
|
||||
noticeConfig := bsq.bridge.Config.BridgeStatusNotices
|
||||
isError := state.StateEvent == status.StateBadCredentials ||
|
||||
state.StateEvent == status.StateUnknownError ||
|
||||
state.UserAction == status.UserActionOpenNative ||
|
||||
(isDelayed && state.StateEvent == status.StateTransientDisconnect)
|
||||
state.UserAction == status.UserActionOpenNative
|
||||
sendNotice := noticeConfig == "all" || (noticeConfig == "errors" &&
|
||||
(isError || (bsq.errorSent && state.StateEvent == status.StateConnected)))
|
||||
if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError {
|
||||
bsq.firstTransientDisconnect = time.Time{}
|
||||
}
|
||||
if !sendNotice {
|
||||
if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect {
|
||||
if bsq.firstTransientDisconnect.IsZero() {
|
||||
bsq.firstTransientDisconnect = time.Now()
|
||||
}
|
||||
go bsq.scheduleNotice(state)
|
||||
}
|
||||
return
|
||||
}
|
||||
managementRoom, err := bsq.login.User.GetManagementRoom(ctx)
|
||||
|
|
@ -154,9 +114,6 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge
|
|||
if state.Error != "" {
|
||||
message += fmt.Sprintf(" (`%s`)", state.Error)
|
||||
}
|
||||
if isDelayed {
|
||||
message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay))
|
||||
}
|
||||
if state.Message != "" {
|
||||
message += fmt.Sprintf(": %s", state.Message)
|
||||
}
|
||||
|
|
@ -194,14 +151,8 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat
|
|||
} else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError {
|
||||
log.Debug().Msg("Not reconnecting as the previous state was not an unknown error")
|
||||
return
|
||||
} else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects {
|
||||
log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached")
|
||||
return
|
||||
}
|
||||
bsq.unknownErrorReconnects++
|
||||
log.Info().
|
||||
Int("reconnect_num", bsq.unknownErrorReconnects).
|
||||
Msg("Disconnecting and reconnecting login due to unknown error")
|
||||
log.Info().Msg("Disconnecting and reconnecting login due to unknown error")
|
||||
bsq.login.Disconnect()
|
||||
log.Debug().Msg("Disconnection finished, recreating client and reconnecting")
|
||||
err := bsq.login.recreateClient(ctx)
|
||||
|
|
@ -220,30 +171,14 @@ func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) b
|
|||
return false
|
||||
}
|
||||
reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2))
|
||||
return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect)
|
||||
}
|
||||
|
||||
const TransientDisconnectNoticeDelay = 3 * time.Minute
|
||||
|
||||
func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool {
|
||||
timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay))
|
||||
zerolog.Ctx(ctx).Trace().
|
||||
Stringer("duration", timeUntilSchedule).
|
||||
Msg("Waiting before sending notice about transient disconnect")
|
||||
return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice)
|
||||
}
|
||||
|
||||
func (bsq *BridgeStateQueue) waitForReconnect(
|
||||
ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc],
|
||||
) bool {
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
if oldCancel := ptr.Swap(&cancel); oldCancel != nil {
|
||||
if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil {
|
||||
(*oldCancel)()
|
||||
}
|
||||
select {
|
||||
case <-time.After(reconnectIn):
|
||||
return ptr.CompareAndSwap(&cancel, nil)
|
||||
return bsq.stopReconnect.CompareAndSwap(&cancel, nil)
|
||||
case <-cancelCtx.Done():
|
||||
return false
|
||||
case <-bsq.stopChan:
|
||||
|
|
@ -263,7 +198,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState)
|
|||
}
|
||||
|
||||
ctx := bsq.login.Log.WithContext(context.Background())
|
||||
bsq.sendNotice(ctx, state, false)
|
||||
bsq.sendNotice(ctx, state)
|
||||
|
||||
retryIn := 2
|
||||
for {
|
||||
|
|
|
|||
|
|
@ -7,13 +7,10 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/event"
|
||||
)
|
||||
|
||||
var CommandRegisterPush = &FullHandler{
|
||||
|
|
@ -62,64 +59,3 @@ var CommandRegisterPush = &FullHandler{
|
|||
RequiresLogin: true,
|
||||
NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI],
|
||||
}
|
||||
|
||||
var CommandSendAccountData = &FullHandler{
|
||||
Func: func(ce *Event) {
|
||||
if len(ce.Args) < 2 {
|
||||
ce.Reply("Usage: `$cmdprefix debug-account-data <type> <content>")
|
||||
return
|
||||
}
|
||||
var content event.Content
|
||||
evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType}
|
||||
ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0]))
|
||||
err := json.Unmarshal([]byte(ce.RawArgs), &content)
|
||||
if err != nil {
|
||||
ce.Reply("Failed to parse JSON: %v", err)
|
||||
return
|
||||
}
|
||||
err = content.ParseRaw(evtType)
|
||||
if err != nil {
|
||||
ce.Reply("Failed to deserialize content: %v", err)
|
||||
return
|
||||
}
|
||||
res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{
|
||||
Sender: ce.User.MXID,
|
||||
Type: evtType,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
RoomID: ce.RoomID,
|
||||
Content: content,
|
||||
})
|
||||
ce.Reply("Result: %+v", res)
|
||||
},
|
||||
Name: "debug-account-data",
|
||||
Help: HelpMeta{
|
||||
Section: HelpSectionAdmin,
|
||||
Description: "Send a room account data event to the bridge",
|
||||
Args: "<_type_> <_content_>",
|
||||
},
|
||||
RequiresAdmin: true,
|
||||
RequiresPortal: true,
|
||||
RequiresLogin: true,
|
||||
}
|
||||
|
||||
var CommandResetNetwork = &FullHandler{
|
||||
Func: func(ce *Event) {
|
||||
if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") {
|
||||
nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork)
|
||||
if ok {
|
||||
nrn.ResetHTTPTransport()
|
||||
} else {
|
||||
ce.Reply("Network connector does not support resetting HTTP transport")
|
||||
}
|
||||
}
|
||||
ce.Bridge.ResetNetworkConnections()
|
||||
ce.React("✅️")
|
||||
},
|
||||
Name: "debug-reset-network",
|
||||
Help: HelpMeta{
|
||||
Section: HelpSectionAdmin,
|
||||
Description: "Reset network connections to the remote network",
|
||||
Args: "[--reset-transport]",
|
||||
},
|
||||
RequiresAdmin: true,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,15 +70,6 @@ func fnLogin(ce *Event) {
|
|||
}
|
||||
ce.Args = ce.Args[1:]
|
||||
}
|
||||
if reauth == nil && ce.User.HasTooManyLogins() {
|
||||
ce.Reply(
|
||||
"You have reached the maximum number of logins (%d). "+
|
||||
"Please logout from an existing login before creating a new one. "+
|
||||
"If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.",
|
||||
ce.User.Permissions.MaxLogins,
|
||||
)
|
||||
return
|
||||
}
|
||||
flows := ce.Bridge.Network.GetLoginFlows()
|
||||
var chosenFlowID string
|
||||
if len(ce.Args) > 0 {
|
||||
|
|
@ -121,7 +112,6 @@ func fnLogin(ce *Event) {
|
|||
ce.Reply("Failed to start login: %v", err)
|
||||
return
|
||||
}
|
||||
ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process")
|
||||
|
||||
nextStep = checkLoginCommandDirectParams(ce, login, nextStep)
|
||||
if nextStep != nil {
|
||||
|
|
@ -200,14 +190,11 @@ type userInputLoginCommandState struct {
|
|||
|
||||
func (uilcs *userInputLoginCommandState) promptNext(ce *Event) {
|
||||
field := uilcs.RemainingFields[0]
|
||||
parts := []string{fmt.Sprintf("Please enter your %s", field.Name)}
|
||||
if field.Description != "" {
|
||||
parts = append(parts, field.Description)
|
||||
ce.Reply("Please enter your %s\n%s", field.Name, field.Description)
|
||||
} else {
|
||||
ce.Reply("Please enter your %s", field.Name)
|
||||
}
|
||||
if len(field.Options) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `")))
|
||||
}
|
||||
ce.Reply(strings.Join(parts, "\n"))
|
||||
StoreCommandState(ce.User, &CommandState{
|
||||
Next: MinimalCommandHandlerFunc(uilcs.submitNext),
|
||||
Action: "Login",
|
||||
|
|
@ -252,19 +239,14 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
|
|||
return fmt.Errorf("failed to upload image: %w", err)
|
||||
}
|
||||
content := &event.MessageEventContent{
|
||||
MsgType: event.MsgImage,
|
||||
FileName: "qr.png",
|
||||
URL: qrMXC,
|
||||
File: qrFile,
|
||||
MsgType: event.MsgImage,
|
||||
FileName: "qr.png",
|
||||
URL: qrMXC,
|
||||
File: qrFile,
|
||||
|
||||
Body: qr,
|
||||
Format: event.FormatHTML,
|
||||
FormattedBody: fmt.Sprintf("<pre><code>%s</code></pre>", html.EscapeString(qr)),
|
||||
Info: &event.FileInfo{
|
||||
MimeType: "image/png",
|
||||
Width: qrSizePx,
|
||||
Height: qrSizePx,
|
||||
Size: len(qrData),
|
||||
},
|
||||
}
|
||||
if *prevEventID != "" {
|
||||
content.SetEdit(*prevEventID)
|
||||
|
|
@ -279,36 +261,6 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error {
|
||||
for _, att := range atts {
|
||||
if att.FileName == "" {
|
||||
return fmt.Errorf("missing attachment filename")
|
||||
}
|
||||
mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err)
|
||||
}
|
||||
content := &event.MessageEventContent{
|
||||
MsgType: att.Type,
|
||||
FileName: att.FileName,
|
||||
URL: mxc,
|
||||
File: file,
|
||||
Info: &event.FileInfo{
|
||||
MimeType: att.Info.MimeType,
|
||||
Width: att.Info.Width,
|
||||
Height: att.Info.Height,
|
||||
Size: att.Info.Size,
|
||||
},
|
||||
Body: att.FileName,
|
||||
}
|
||||
_, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
|
|
@ -500,7 +452,6 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string {
|
|||
}
|
||||
|
||||
func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) {
|
||||
ce.Log.Debug().Any("next_step", step).Msg("Got next login step")
|
||||
if step.Instructions != "" {
|
||||
ce.Reply(step.Instructions)
|
||||
}
|
||||
|
|
@ -515,10 +466,6 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte
|
|||
Override: override,
|
||||
}).prompt(ce)
|
||||
case bridgev2.LoginStepTypeUserInput:
|
||||
err := sendUserInputAttachments(ce, step.UserInputParams.Attachments)
|
||||
if err != nil {
|
||||
ce.Reply("Failed to send attachments: %v", err)
|
||||
}
|
||||
(&userInputLoginCommandState{
|
||||
Login: login.(bridgev2.LoginProcessUserInput),
|
||||
RemainingFields: step.UserInputParams.Fields,
|
||||
|
|
|
|||
|
|
@ -41,11 +41,10 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor {
|
|||
}
|
||||
proc.AddHandlers(
|
||||
CommandHelp, CommandCancel,
|
||||
CommandRegisterPush, CommandSendAccountData, CommandResetNetwork,
|
||||
CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
|
||||
CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom,
|
||||
CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin,
|
||||
CommandSetRelay, CommandUnsetRelay,
|
||||
CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute,
|
||||
CommandResolveIdentifier, CommandStartChat, CommandSearch,
|
||||
CommandSudo, CommandDoIn,
|
||||
)
|
||||
return proc
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) {
|
|||
}
|
||||
onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly
|
||||
var relay *bridgev2.UserLogin
|
||||
if len(ce.Args) == 0 && ce.Portal.Receiver == "" {
|
||||
if len(ce.Args) == 0 {
|
||||
relay = ce.User.GetDefaultLogin()
|
||||
isLoggedIn := relay != nil
|
||||
if onlySetDefaultRelays {
|
||||
|
|
@ -73,19 +73,9 @@ func fnSetRelay(ce *Event) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
var targetID networkid.UserLoginID
|
||||
if ce.Portal.Receiver != "" {
|
||||
targetID = ce.Portal.Receiver
|
||||
if len(ce.Args) > 0 && ce.Args[0] != string(targetID) {
|
||||
ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
targetID = networkid.UserLoginID(ce.Args[0])
|
||||
}
|
||||
relay = ce.Bridge.GetCachedUserLoginByID(targetID)
|
||||
relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
|
||||
if relay == nil {
|
||||
ce.Reply("User login with ID `%s` not found", targetID)
|
||||
ce.Reply("User login with ID `%s` not found", ce.Args[0])
|
||||
return
|
||||
} else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) {
|
||||
// All good
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -8,21 +8,13 @@ package commands
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/bridgev2/provisionutil"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
|
|
@ -38,35 +30,6 @@ var CommandResolveIdentifier = &FullHandler{
|
|||
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
|
||||
}
|
||||
|
||||
var CommandSyncChat = &FullHandler{
|
||||
Func: func(ce *Event) {
|
||||
login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to find login for sync")
|
||||
ce.Reply("Failed to find login: %v", err)
|
||||
return
|
||||
} else if login == nil {
|
||||
ce.Reply("No login found for sync")
|
||||
return
|
||||
}
|
||||
info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to get chat info for sync")
|
||||
ce.Reply("Failed to get chat info: %v", err)
|
||||
return
|
||||
}
|
||||
ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{})
|
||||
ce.React("✅️")
|
||||
},
|
||||
Name: "sync-portal",
|
||||
Help: HelpMeta{
|
||||
Section: HelpSectionChats,
|
||||
Description: "Sync the current portal room",
|
||||
},
|
||||
RequiresPortal: true,
|
||||
RequiresLogin: true,
|
||||
}
|
||||
|
||||
var CommandStartChat = &FullHandler{
|
||||
Func: fnResolveIdentifier,
|
||||
Name: "start-chat",
|
||||
|
|
@ -80,15 +43,9 @@ var CommandStartChat = &FullHandler{
|
|||
NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI],
|
||||
}
|
||||
|
||||
func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
|
||||
var remainingArgs []string
|
||||
if len(ce.Args) > 1 {
|
||||
remainingArgs = ce.Args[1:]
|
||||
}
|
||||
var login *bridgev2.UserLogin
|
||||
if len(ce.Args) > 0 {
|
||||
login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
|
||||
}
|
||||
func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) {
|
||||
remainingArgs := ce.Args[1:]
|
||||
login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0]))
|
||||
if login == nil || login.UserMXID != ce.User.MXID {
|
||||
remainingArgs = ce.Args
|
||||
login = ce.User.GetDefaultLogin()
|
||||
|
|
@ -100,13 +57,24 @@ func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*
|
|||
return login, api, remainingArgs
|
||||
}
|
||||
|
||||
func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string {
|
||||
if resp.MXID != "" {
|
||||
return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL())
|
||||
} else if resp.Name != "" {
|
||||
return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name)
|
||||
func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string {
|
||||
var targetName string
|
||||
var targetMXID id.UserID
|
||||
if resp.Ghost != nil {
|
||||
if resp.UserInfo != nil {
|
||||
resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
|
||||
}
|
||||
targetName = resp.Ghost.Name
|
||||
targetMXID = resp.Ghost.Intent.GetMXID()
|
||||
} else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
|
||||
targetName = *resp.UserInfo.Name
|
||||
}
|
||||
if targetMXID != "" {
|
||||
return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL())
|
||||
} else if targetName != "" {
|
||||
return fmt.Sprintf("`%s` / %s", resp.UserID, targetName)
|
||||
} else {
|
||||
return fmt.Sprintf("`%s`", resp.ID)
|
||||
return fmt.Sprintf("`%s`", resp.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -119,137 +87,65 @@ func fnResolveIdentifier(ce *Event) {
|
|||
if api == nil {
|
||||
return
|
||||
}
|
||||
allLogins := ce.User.GetUserLogins()
|
||||
createChat := ce.Command == "start-chat" || ce.Command == "pm"
|
||||
identifier := strings.Join(identifierParts, " ")
|
||||
resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat)
|
||||
for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ {
|
||||
resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat)
|
||||
}
|
||||
resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to resolve identifier")
|
||||
ce.Reply("Failed to resolve identifier: %v", err)
|
||||
return
|
||||
} else if resp == nil {
|
||||
ce.ReplyAdvanced(fmt.Sprintf("Identifier <code>%s</code> not found", html.EscapeString(identifier)), false, true)
|
||||
return
|
||||
}
|
||||
formattedName := formatResolveIdentifierResult(resp)
|
||||
formattedName := formatResolveIdentifierResult(ce.Ctx, resp)
|
||||
if createChat {
|
||||
name := resp.Portal.Name
|
||||
if name == "" {
|
||||
name = resp.Portal.MXID.String()
|
||||
if resp.Chat == nil {
|
||||
ce.Reply("Interface error: network connector did not return chat for create chat request")
|
||||
return
|
||||
}
|
||||
if !resp.JustCreated {
|
||||
ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
|
||||
portal := resp.Chat.Portal
|
||||
if portal == nil {
|
||||
portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to get portal")
|
||||
ce.Reply("Failed to get portal: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if resp.Chat.PortalInfo == nil {
|
||||
resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to get portal info")
|
||||
ce.Reply("Failed to get portal info: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if portal.MXID != "" {
|
||||
name := portal.Name
|
||||
if name == "" {
|
||||
name = portal.MXID.String()
|
||||
}
|
||||
portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{})
|
||||
ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
|
||||
} else {
|
||||
ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL())
|
||||
err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to create room")
|
||||
ce.Reply("Failed to create room: %v", err)
|
||||
return
|
||||
}
|
||||
name := portal.Name
|
||||
if name == "" {
|
||||
name = portal.MXID.String()
|
||||
}
|
||||
ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL())
|
||||
}
|
||||
} else {
|
||||
ce.Reply("Found %s", formattedName)
|
||||
}
|
||||
}
|
||||
|
||||
var CommandCreateGroup = &FullHandler{
|
||||
Func: fnCreateGroup,
|
||||
Name: "create-group",
|
||||
Aliases: []string{"create"},
|
||||
Help: HelpMeta{
|
||||
Section: HelpSectionChats,
|
||||
Description: "Create a new group chat for the current Matrix room",
|
||||
Args: "[_group type_]",
|
||||
},
|
||||
RequiresLogin: true,
|
||||
NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI],
|
||||
}
|
||||
|
||||
func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) {
|
||||
evt, err := provider.GetStateEvent(ctx, roomID, evtType, "")
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation")
|
||||
} else if evt != nil {
|
||||
content, _ = evt.Content.Parsed.(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fnCreateGroup(ce *Event) {
|
||||
ce.Bridge.Matrix.GetCapabilities()
|
||||
login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group")
|
||||
if api == nil {
|
||||
return
|
||||
}
|
||||
stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState)
|
||||
if !ok {
|
||||
ce.Reply("Matrix connector doesn't support fetching room state")
|
||||
return
|
||||
}
|
||||
members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to get room members for group creation")
|
||||
ce.Reply("Failed to get room members: %v", err)
|
||||
return
|
||||
}
|
||||
caps := ce.Bridge.Network.GetCapabilities()
|
||||
params := &bridgev2.GroupCreateParams{
|
||||
Username: "",
|
||||
Participants: make([]networkid.UserID, 0, len(members)-2),
|
||||
Parent: nil, // TODO check space parent event
|
||||
Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider),
|
||||
Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider),
|
||||
Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider),
|
||||
Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider),
|
||||
RoomID: ce.RoomID,
|
||||
}
|
||||
for userID, member := range members {
|
||||
if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() {
|
||||
continue
|
||||
}
|
||||
if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok {
|
||||
params.Participants = append(params.Participants, parsedUserID)
|
||||
} else if !ce.Bridge.Config.SplitPortals {
|
||||
if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil {
|
||||
ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member")
|
||||
} else if user != nil {
|
||||
// TODO add user logins to participants
|
||||
//for _, login := range user.GetUserLogins() {
|
||||
// params.Participants = append(params.Participants, login.GetUserID())
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(caps.Provisioning.GroupCreation) == 0 {
|
||||
ce.Reply("No group creation types defined in network capabilities")
|
||||
return
|
||||
} else if len(remainingArgs) > 0 {
|
||||
params.Type = remainingArgs[0]
|
||||
} else if len(caps.Provisioning.GroupCreation) == 1 {
|
||||
for params.Type = range caps.Provisioning.GroupCreation {
|
||||
// The loop assigns the variable we want
|
||||
}
|
||||
} else {
|
||||
types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `")
|
||||
ce.Reply("Please specify type of group to create: `%s`", types)
|
||||
return
|
||||
}
|
||||
resp, err := provisionutil.CreateGroup(ce.Ctx, login, params)
|
||||
if err != nil {
|
||||
ce.Reply("Failed to create group: %v", err)
|
||||
return
|
||||
}
|
||||
var postfix string
|
||||
if len(resp.FailedParticipants) > 0 {
|
||||
failedParticipantsStrings := make([]string, len(resp.FailedParticipants))
|
||||
i := 0
|
||||
for participantID, meta := range resp.FailedParticipants {
|
||||
failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason)
|
||||
i++
|
||||
}
|
||||
postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n")
|
||||
}
|
||||
ce.Reply("Successfully created group `%s`%s", resp.ID, postfix)
|
||||
}
|
||||
|
||||
var CommandSearch = &FullHandler{
|
||||
Func: fnSearch,
|
||||
Name: "search",
|
||||
|
|
@ -267,67 +163,35 @@ func fnSearch(ce *Event) {
|
|||
ce.Reply("Usage: `$cmdprefix search <query>`")
|
||||
return
|
||||
}
|
||||
login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
|
||||
_, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users")
|
||||
if api == nil {
|
||||
return
|
||||
}
|
||||
resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " "))
|
||||
results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " "))
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Msg("Failed to search for users")
|
||||
ce.Reply("Failed to search for users: %v", err)
|
||||
return
|
||||
}
|
||||
resultsString := make([]string, len(resp.Results))
|
||||
for i, res := range resp.Results {
|
||||
formattedName := formatResolveIdentifierResult(res)
|
||||
resultsString := make([]string, len(results))
|
||||
for i, res := range results {
|
||||
formattedName := formatResolveIdentifierResult(ce.Ctx, res)
|
||||
resultsString[i] = fmt.Sprintf("* %s", formattedName)
|
||||
if res.Portal != nil && res.Portal.MXID != "" {
|
||||
portalName := res.Portal.Name
|
||||
if portalName == "" {
|
||||
portalName = res.Portal.MXID.String()
|
||||
if res.Chat != nil {
|
||||
if res.Chat.Portal == nil {
|
||||
res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey)
|
||||
if err != nil {
|
||||
ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal")
|
||||
}
|
||||
}
|
||||
if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" {
|
||||
portalName := res.Chat.Portal.Name
|
||||
if portalName == "" {
|
||||
portalName = res.Chat.Portal.MXID.String()
|
||||
}
|
||||
resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL())
|
||||
}
|
||||
resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL())
|
||||
}
|
||||
}
|
||||
ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n"))
|
||||
}
|
||||
|
||||
var CommandMute = &FullHandler{
|
||||
Func: fnMute,
|
||||
Name: "mute",
|
||||
Aliases: []string{"unmute"},
|
||||
Help: HelpMeta{
|
||||
Section: HelpSectionChats,
|
||||
Description: "Mute or unmute a chat on the remote network",
|
||||
Args: "[duration]",
|
||||
},
|
||||
RequiresPortal: true,
|
||||
RequiresLogin: true,
|
||||
NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI],
|
||||
}
|
||||
|
||||
func fnMute(ce *Event) {
|
||||
_, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats")
|
||||
var mutedUntil int64
|
||||
if ce.Command == "mute" {
|
||||
mutedUntil = -1
|
||||
if len(ce.Args) > 0 {
|
||||
duration, err := time.ParseDuration(ce.Args[0])
|
||||
if err != nil {
|
||||
ce.Reply("Invalid duration: %v", err)
|
||||
return
|
||||
}
|
||||
mutedUntil = time.Now().Add(duration).UnixMilli()
|
||||
}
|
||||
}
|
||||
err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{
|
||||
MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{
|
||||
Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil},
|
||||
Portal: ce.Portal,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
ce.Reply("Failed to %s chat: %v", ce.Command, err)
|
||||
} else {
|
||||
ce.React("✅️")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,13 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"golang.org/x/exp/constraints"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
|
||||
|
|
@ -28,7 +34,6 @@ type Database struct {
|
|||
UserPortal *UserPortalQuery
|
||||
BackfillTask *BackfillTaskQuery
|
||||
KV *KVQuery
|
||||
PublicMedia *PublicMediaQuery
|
||||
}
|
||||
|
||||
type MetaMerger interface {
|
||||
|
|
@ -136,12 +141,6 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa
|
|||
BridgeID: bridgeID,
|
||||
Database: db,
|
||||
},
|
||||
PublicMedia: &PublicMediaQuery{
|
||||
BridgeID: bridgeID,
|
||||
QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia {
|
||||
return &PublicMedia{}
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -152,3 +151,55 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID)
|
|||
panic("bridge ID mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) {
|
||||
if val, found := m[key]; found {
|
||||
floatVal, ok := val.(float64)
|
||||
if ok {
|
||||
return T(floatVal), true
|
||||
}
|
||||
tVal, ok := val.(T)
|
||||
if ok {
|
||||
return tVal, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func unmarshalMerge(input []byte, data any, extra *map[string]any) error {
|
||||
err := json.Unmarshal(input, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(input, extra)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *extra == nil {
|
||||
*extra = make(map[string]any)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshalMerge(data any, extra map[string]any) ([]byte, error) {
|
||||
if extra == nil {
|
||||
return json.Marshal(data)
|
||||
}
|
||||
merged := make(map[string]any)
|
||||
maps.Copy(merged, extra)
|
||||
dataRef := reflect.ValueOf(data).Elem()
|
||||
dataType := dataRef.Type()
|
||||
for _, field := range reflect.VisibleFields(dataType) {
|
||||
parts := strings.Split(field.Tag.Get("json"), ",")
|
||||
if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" {
|
||||
continue
|
||||
}
|
||||
fieldVal := dataRef.FieldByIndex(field.Index)
|
||||
if fieldVal.IsZero() {
|
||||
delete(merged, parts[0])
|
||||
} else {
|
||||
merged[parts[0]] = fieldVal.Interface()
|
||||
}
|
||||
}
|
||||
return json.Marshal(merged)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,92 +12,54 @@ import (
|
|||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/jsontime"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Deprecated: use [event.DisappearingType]
|
||||
type DisappearingType = event.DisappearingType
|
||||
// DisappearingType represents the type of a disappearing message timer.
|
||||
type DisappearingType string
|
||||
|
||||
// Deprecated: use constants in event package
|
||||
const (
|
||||
DisappearingTypeNone = event.DisappearingTypeNone
|
||||
DisappearingTypeAfterRead = event.DisappearingTypeAfterRead
|
||||
DisappearingTypeAfterSend = event.DisappearingTypeAfterSend
|
||||
DisappearingTypeNone DisappearingType = ""
|
||||
DisappearingTypeAfterRead DisappearingType = "after_read"
|
||||
DisappearingTypeAfterSend DisappearingType = "after_send"
|
||||
)
|
||||
|
||||
// DisappearingSetting represents a disappearing message timer setting
|
||||
// by combining a type with a timer and an optional start timestamp.
|
||||
type DisappearingSetting struct {
|
||||
Type event.DisappearingType
|
||||
Type DisappearingType
|
||||
Timer time.Duration
|
||||
DisappearAt time.Time
|
||||
}
|
||||
|
||||
func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting {
|
||||
if evt == nil || evt.Type == event.DisappearingTypeNone {
|
||||
return DisappearingSetting{}
|
||||
}
|
||||
return DisappearingSetting{
|
||||
Type: evt.Type,
|
||||
Timer: evt.Timer.Duration,
|
||||
}
|
||||
}
|
||||
|
||||
func (ds DisappearingSetting) Normalize() DisappearingSetting {
|
||||
if ds.Type == event.DisappearingTypeNone {
|
||||
ds.Timer = 0
|
||||
} else if ds.Timer == 0 {
|
||||
ds.Type = event.DisappearingTypeNone
|
||||
}
|
||||
return ds
|
||||
}
|
||||
|
||||
func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting {
|
||||
ds.DisappearAt = start.Add(ds.Timer)
|
||||
return ds
|
||||
}
|
||||
|
||||
func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer {
|
||||
if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 {
|
||||
return &event.BeeperDisappearingTimer{}
|
||||
}
|
||||
return &event.BeeperDisappearingTimer{
|
||||
Type: ds.Type,
|
||||
Timer: jsontime.MS(ds.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
type DisappearingMessageQuery struct {
|
||||
BridgeID networkid.BridgeID
|
||||
*dbutil.QueryHelper[*DisappearingMessage]
|
||||
}
|
||||
|
||||
type DisappearingMessage struct {
|
||||
BridgeID networkid.BridgeID
|
||||
RoomID id.RoomID
|
||||
EventID id.EventID
|
||||
Timestamp time.Time
|
||||
BridgeID networkid.BridgeID
|
||||
RoomID id.RoomID
|
||||
EventID id.EventID
|
||||
DisappearingSetting
|
||||
}
|
||||
|
||||
const (
|
||||
upsertDisappearingMessageQuery = `
|
||||
INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at
|
||||
`
|
||||
startDisappearingMessagesQuery = `
|
||||
UPDATE disappearing_message
|
||||
SET disappear_at=$1 + timer
|
||||
WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4
|
||||
RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
|
||||
WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read'
|
||||
RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at
|
||||
`
|
||||
getUpcomingDisappearingMessagesQuery = `
|
||||
SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at
|
||||
SELECT bridge_id, mx_room, mxid, type, timer, disappear_at
|
||||
FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2
|
||||
ORDER BY disappear_at LIMIT $3
|
||||
`
|
||||
|
|
@ -111,8 +73,8 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe
|
|||
return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) {
|
||||
return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano())
|
||||
func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) {
|
||||
return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID)
|
||||
}
|
||||
|
||||
func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) {
|
||||
|
|
@ -124,19 +86,17 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even
|
|||
}
|
||||
|
||||
func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
|
||||
var timestamp int64
|
||||
var disappearAt sql.NullInt64
|
||||
err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt)
|
||||
err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if disappearAt.Valid {
|
||||
d.DisappearAt = time.Unix(0, disappearAt.Int64)
|
||||
}
|
||||
d.Timestamp = time.Unix(0, timestamp)
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *DisappearingMessage) sqlVariables() []any {
|
||||
return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
|
||||
return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,17 +7,12 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exerrors"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
|
|
@ -27,55 +22,6 @@ type GhostQuery struct {
|
|||
*dbutil.QueryHelper[*Ghost]
|
||||
}
|
||||
|
||||
type ExtraProfile map[string]json.RawMessage
|
||||
|
||||
func (ep *ExtraProfile) Set(key string, value any) error {
|
||||
if key == "displayname" || key == "avatar_url" {
|
||||
return fmt.Errorf("cannot set reserved profile key %q", key)
|
||||
}
|
||||
marshaled, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if *ep == nil {
|
||||
*ep = make(ExtraProfile)
|
||||
}
|
||||
(*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ep *ExtraProfile) With(key string, value any) *ExtraProfile {
|
||||
exerrors.PanicIfNotNil(ep.Set(key, value))
|
||||
return ep
|
||||
}
|
||||
|
||||
func canonicalizeIfObject(data json.RawMessage) json.RawMessage {
|
||||
if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
|
||||
return canonicaljson.CanonicalJSONAssumeValid(data)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) {
|
||||
if len(*ep) == 0 {
|
||||
return
|
||||
}
|
||||
if *dest == nil {
|
||||
*dest = make(ExtraProfile)
|
||||
}
|
||||
for key, val := range *ep {
|
||||
if key == "displayname" || key == "avatar_url" {
|
||||
continue
|
||||
}
|
||||
existing, exists := (*dest)[key]
|
||||
if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) {
|
||||
(*dest)[key] = val
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Ghost struct {
|
||||
BridgeID networkid.BridgeID
|
||||
ID networkid.UserID
|
||||
|
|
@ -89,14 +35,13 @@ type Ghost struct {
|
|||
ContactInfoSet bool
|
||||
IsBot bool
|
||||
Identifiers []string
|
||||
ExtraProfile ExtraProfile
|
||||
Metadata any
|
||||
}
|
||||
|
||||
const (
|
||||
getGhostBaseQuery = `
|
||||
SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
|
||||
name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
|
||||
name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
|
||||
FROM ghost
|
||||
`
|
||||
getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2`
|
||||
|
|
@ -104,14 +49,13 @@ const (
|
|||
insertGhostQuery = `
|
||||
INSERT INTO ghost (
|
||||
bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc,
|
||||
name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata
|
||||
name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
`
|
||||
updateGhostQuery = `
|
||||
UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6,
|
||||
name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10,
|
||||
identifiers=$11, extra_profile=$12, metadata=$13
|
||||
name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12
|
||||
WHERE bridge_id=$1 AND id=$2
|
||||
`
|
||||
)
|
||||
|
|
@ -142,7 +86,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) {
|
|||
&g.BridgeID, &g.ID,
|
||||
&g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC,
|
||||
&g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot,
|
||||
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
|
||||
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -172,6 +116,6 @@ func (g *Ghost) sqlVariables() []any {
|
|||
g.BridgeID, g.ID,
|
||||
g.Name, g.AvatarID, avatarHash, g.AvatarMXC,
|
||||
g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot,
|
||||
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata},
|
||||
dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ const (
|
|||
KeySplitPortalsEnabled Key = "split_portals_enabled"
|
||||
KeyBridgeInfoVersion Key = "bridge_info_version"
|
||||
KeyEncryptionStateResynced Key = "encryption_state_resynced"
|
||||
KeyRecoveryKey Key = "recovery_key"
|
||||
)
|
||||
|
||||
type KVQuery struct {
|
||||
|
|
|
|||
|
|
@ -11,12 +11,9 @@ import (
|
|||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
|
|
@ -27,7 +24,6 @@ type MessageQuery struct {
|
|||
BridgeID networkid.BridgeID
|
||||
MetaType MetaTypeCreator
|
||||
*dbutil.QueryHelper[*Message]
|
||||
chunkDeleteLock sync.Mutex
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
|
|
@ -68,8 +64,8 @@ const (
|
|||
getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1`
|
||||
getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5`
|
||||
getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1`
|
||||
getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1`
|
||||
getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1`
|
||||
getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1`
|
||||
getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1`
|
||||
getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4`
|
||||
|
||||
getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1`
|
||||
|
|
@ -100,10 +96,6 @@ const (
|
|||
deleteMessagePartByRowIDQuery = `
|
||||
DELETE FROM message WHERE bridge_id=$1 AND rowid=$2
|
||||
`
|
||||
deleteMessageChunkQuery = `
|
||||
DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5
|
||||
`
|
||||
getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1`
|
||||
)
|
||||
|
||||
func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) {
|
||||
|
|
@ -188,85 +180,6 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error {
|
|||
return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID)
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) {
|
||||
res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) {
|
||||
err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID)
|
||||
return
|
||||
}
|
||||
|
||||
const deleteChunkSize = 100_000
|
||||
|
||||
func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error {
|
||||
if mq.GetDB().Dialect != dbutil.SQLite {
|
||||
return nil
|
||||
}
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("action", "delete messages in chunks").
|
||||
Stringer("portal_key", portal).
|
||||
Logger()
|
||||
if !mq.chunkDeleteLock.TryLock() {
|
||||
log.Warn().Msg("Portal deletion lock is being held, waiting...")
|
||||
mq.chunkDeleteLock.Lock()
|
||||
log.Debug().Msg("Acquired portal deletion lock after waiting")
|
||||
}
|
||||
defer mq.chunkDeleteLock.Unlock()
|
||||
total, err := mq.CountMessagesInPortal(ctx, portal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count messages in portal: %w", err)
|
||||
} else if total < deleteChunkSize/3 {
|
||||
return nil
|
||||
}
|
||||
globalMaxRowID, err := mq.getMaxRowID(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get max row ID: %w", err)
|
||||
}
|
||||
log.Debug().
|
||||
Int("total_count", total).
|
||||
Int64("global_max_row_id", globalMaxRowID).
|
||||
Msg("Portal has lots of messages, deleting in chunks to avoid database locks")
|
||||
maxRowID := int64(deleteChunkSize)
|
||||
globalMaxRowID += deleteChunkSize * 1.2
|
||||
var dbTimeUsed time.Duration
|
||||
globalStart := time.Now()
|
||||
for total > 500 && maxRowID < globalMaxRowID {
|
||||
start := time.Now()
|
||||
count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID)
|
||||
duration := time.Since(start)
|
||||
dbTimeUsed += duration
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err)
|
||||
}
|
||||
total -= int(count)
|
||||
maxRowID += deleteChunkSize
|
||||
sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond))
|
||||
log.Debug().
|
||||
Int64("max_row_id", maxRowID).
|
||||
Int64("deleted_count", count).
|
||||
Int("remaining_count", total).
|
||||
Dur("duration", duration).
|
||||
Dur("sleep_time", sleepTime).
|
||||
Msg("Deleted chunk of messages")
|
||||
select {
|
||||
case <-time.After(sleepTime):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
log.Debug().
|
||||
Int("remaining_count", total).
|
||||
Dur("db_time_used", dbTimeUsed).
|
||||
Dur("total_duration", time.Since(globalStart)).
|
||||
Msg("Finished chunked delete of messages in portal")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) {
|
||||
err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
|
|
@ -35,20 +34,9 @@ type PortalQuery struct {
|
|||
*dbutil.QueryHelper[*Portal]
|
||||
}
|
||||
|
||||
type CapStateFlags uint32
|
||||
|
||||
func (csf CapStateFlags) Has(flag CapStateFlags) bool {
|
||||
return csf&flag != 0
|
||||
}
|
||||
|
||||
const (
|
||||
CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota
|
||||
)
|
||||
|
||||
type CapabilityState struct {
|
||||
Source networkid.UserLoginID `json:"source"`
|
||||
ID string `json:"id"`
|
||||
Flags CapStateFlags `json:"flags"`
|
||||
}
|
||||
|
||||
type Portal struct {
|
||||
|
|
@ -56,31 +44,30 @@ type Portal struct {
|
|||
networkid.PortalKey
|
||||
MXID id.RoomID
|
||||
|
||||
ParentKey networkid.PortalKey
|
||||
RelayLoginID networkid.UserLoginID
|
||||
OtherUserID networkid.UserID
|
||||
Name string
|
||||
Topic string
|
||||
AvatarID networkid.AvatarID
|
||||
AvatarHash [32]byte
|
||||
AvatarMXC id.ContentURIString
|
||||
NameSet bool
|
||||
TopicSet bool
|
||||
AvatarSet bool
|
||||
NameIsCustom bool
|
||||
InSpace bool
|
||||
MessageRequest bool
|
||||
RoomType RoomType
|
||||
Disappear DisappearingSetting
|
||||
CapState CapabilityState
|
||||
Metadata any
|
||||
ParentKey networkid.PortalKey
|
||||
RelayLoginID networkid.UserLoginID
|
||||
OtherUserID networkid.UserID
|
||||
Name string
|
||||
Topic string
|
||||
AvatarID networkid.AvatarID
|
||||
AvatarHash [32]byte
|
||||
AvatarMXC id.ContentURIString
|
||||
NameSet bool
|
||||
TopicSet bool
|
||||
AvatarSet bool
|
||||
NameIsCustom bool
|
||||
InSpace bool
|
||||
RoomType RoomType
|
||||
Disappear DisappearingSetting
|
||||
CapState CapabilityState
|
||||
Metadata any
|
||||
}
|
||||
|
||||
const (
|
||||
getPortalBaseQuery = `
|
||||
SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id,
|
||||
name, topic, avatar_id, avatar_hash, avatar_mxc,
|
||||
name_set, topic_set, avatar_set, name_is_custom, in_space, message_request,
|
||||
name_set, topic_set, avatar_set, name_is_custom, in_space,
|
||||
room_type, disappear_type, disappear_timer, cap_state,
|
||||
metadata
|
||||
FROM portal
|
||||
|
|
@ -89,9 +76,7 @@ const (
|
|||
getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')`
|
||||
getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2`
|
||||
getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL`
|
||||
getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC`
|
||||
getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2`
|
||||
getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3`
|
||||
getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1`
|
||||
getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3`
|
||||
|
||||
|
|
@ -102,11 +87,11 @@ const (
|
|||
bridge_id, id, receiver, mxid,
|
||||
parent_id, parent_receiver, relay_login_id, other_user_id,
|
||||
name, topic, avatar_id, avatar_hash, avatar_mxc,
|
||||
name_set, avatar_set, topic_set, name_is_custom, in_space, message_request,
|
||||
name_set, avatar_set, topic_set, name_is_custom, in_space,
|
||||
room_type, disappear_type, disappear_timer, cap_state,
|
||||
metadata, relay_bridge_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24,
|
||||
$1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23,
|
||||
CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END
|
||||
)
|
||||
`
|
||||
|
|
@ -115,8 +100,8 @@ const (
|
|||
SET mxid=$4, parent_id=$5, parent_receiver=$6,
|
||||
relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END,
|
||||
other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13,
|
||||
name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19,
|
||||
room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24
|
||||
name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18,
|
||||
room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23
|
||||
WHERE bridge_id=$1 AND id=$2 AND receiver=$3
|
||||
`
|
||||
deletePortalQuery = `
|
||||
|
|
@ -126,33 +111,15 @@ const (
|
|||
reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3`
|
||||
migrateToSplitPortalsQuery = `
|
||||
UPDATE portal
|
||||
SET receiver=new_receiver
|
||||
FROM (
|
||||
SELECT bridge_id, id, COALESCE((
|
||||
SELECT login_id
|
||||
FROM user_portal
|
||||
WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
|
||||
LIMIT 1
|
||||
), (
|
||||
SELECT login_id
|
||||
FROM user_portal
|
||||
WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id
|
||||
LIMIT 1
|
||||
), (
|
||||
SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
|
||||
), '') AS new_receiver
|
||||
FROM portal
|
||||
WHERE receiver='' AND bridge_id=$1
|
||||
) updates
|
||||
WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS (
|
||||
SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver
|
||||
)
|
||||
`
|
||||
fixParentsAfterSplitPortalMigrationQuery = `
|
||||
UPDATE portal
|
||||
SET parent_receiver=receiver
|
||||
WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''
|
||||
AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver);
|
||||
SET receiver=COALESCE((
|
||||
SELECT login_id
|
||||
FROM user_portal
|
||||
WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver=''
|
||||
LIMIT 1
|
||||
), (
|
||||
SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1
|
||||
), '')
|
||||
WHERE receiver='' AND bridge_id=$1
|
||||
`
|
||||
)
|
||||
|
||||
|
|
@ -180,10 +147,6 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) {
|
|||
return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) {
|
||||
return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) {
|
||||
return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID)
|
||||
}
|
||||
|
|
@ -192,10 +155,6 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid.
|
|||
return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) {
|
||||
return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) {
|
||||
return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver)
|
||||
}
|
||||
|
|
@ -226,14 +185,6 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error)
|
|||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) {
|
||||
res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
|
||||
var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString
|
||||
var disappearTimer sql.NullInt64
|
||||
|
|
@ -242,7 +193,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
|
|||
&p.BridgeID, &p.ID, &p.Receiver, &mxid,
|
||||
&parentID, &parentReceiver, &relayLoginID, &otherUserID,
|
||||
&p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC,
|
||||
&p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest,
|
||||
&p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace,
|
||||
&p.RoomType, &disappearType, &disappearTimer,
|
||||
dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata},
|
||||
)
|
||||
|
|
@ -257,7 +208,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
|
|||
}
|
||||
if disappearType.Valid {
|
||||
p.Disappear = DisappearingSetting{
|
||||
Type: event.DisappearingType(disappearType.String),
|
||||
Type: DisappearingType(disappearType.String),
|
||||
Timer: time.Duration(disappearTimer.Int64),
|
||||
}
|
||||
}
|
||||
|
|
@ -289,7 +240,7 @@ func (p *Portal) sqlVariables() []any {
|
|||
p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID),
|
||||
dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID),
|
||||
p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC,
|
||||
p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest,
|
||||
p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace,
|
||||
p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer),
|
||||
dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,72 +0,0 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/crypto/attachment"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type PublicMediaQuery struct {
|
||||
BridgeID networkid.BridgeID
|
||||
*dbutil.QueryHelper[*PublicMedia]
|
||||
}
|
||||
|
||||
type PublicMedia struct {
|
||||
BridgeID networkid.BridgeID
|
||||
PublicID string
|
||||
MXC id.ContentURI
|
||||
Keys *attachment.EncryptedFile
|
||||
MimeType string
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
upsertPublicMediaQuery = `
|
||||
INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry
|
||||
`
|
||||
getPublicMediaQuery = `
|
||||
SELECT bridge_id, public_id, mxc, keys, mimetype, expiry
|
||||
FROM public_media WHERE bridge_id=$1 AND public_id=$2
|
||||
`
|
||||
)
|
||||
|
||||
func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error {
|
||||
ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID)
|
||||
return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) {
|
||||
return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID)
|
||||
}
|
||||
|
||||
func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) {
|
||||
var expiry sql.NullInt64
|
||||
var mimetype sql.NullString
|
||||
err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expiry.Valid {
|
||||
pm.Expiry = time.Unix(0, expiry.Int64)
|
||||
}
|
||||
pm.MimeType = mimetype.String
|
||||
return pm, nil
|
||||
}
|
||||
|
||||
func (pm *PublicMedia) sqlVariables() []any {
|
||||
return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
-- v0 -> v27 (compatible with v9+): Latest revision
|
||||
-- v0 -> v22 (compatible with v9+): Latest revision
|
||||
CREATE TABLE "user" (
|
||||
bridge_id TEXT NOT NULL,
|
||||
mxid TEXT NOT NULL,
|
||||
|
|
@ -48,7 +48,6 @@ CREATE TABLE portal (
|
|||
topic_set BOOLEAN NOT NULL,
|
||||
name_is_custom BOOLEAN NOT NULL DEFAULT false,
|
||||
in_space BOOLEAN NOT NULL,
|
||||
message_request BOOLEAN NOT NULL DEFAULT false,
|
||||
room_type TEXT NOT NULL,
|
||||
disappear_type TEXT,
|
||||
disappear_timer BIGINT,
|
||||
|
|
@ -65,7 +64,6 @@ CREATE TABLE portal (
|
|||
ON DELETE SET NULL ON UPDATE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid);
|
||||
CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
|
||||
|
||||
CREATE TABLE ghost (
|
||||
bridge_id TEXT NOT NULL,
|
||||
|
|
@ -80,7 +78,6 @@ CREATE TABLE ghost (
|
|||
contact_info_set BOOLEAN NOT NULL,
|
||||
is_bot BOOLEAN NOT NULL,
|
||||
identifiers jsonb NOT NULL,
|
||||
extra_profile jsonb,
|
||||
metadata jsonb NOT NULL,
|
||||
|
||||
PRIMARY KEY (bridge_id, id)
|
||||
|
|
@ -130,7 +127,6 @@ CREATE TABLE disappearing_message (
|
|||
bridge_id TEXT NOT NULL,
|
||||
mx_room TEXT NOT NULL,
|
||||
mxid TEXT NOT NULL,
|
||||
timestamp BIGINT NOT NULL DEFAULT 0,
|
||||
type TEXT NOT NULL,
|
||||
timer BIGINT NOT NULL,
|
||||
disappear_at BIGINT,
|
||||
|
|
@ -141,7 +137,6 @@ CREATE TABLE disappearing_message (
|
|||
REFERENCES portal (bridge_id, mxid)
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
|
||||
|
||||
CREATE TABLE reaction (
|
||||
bridge_id TEXT NOT NULL,
|
||||
|
|
@ -220,14 +215,3 @@ CREATE TABLE kv_store (
|
|||
|
||||
PRIMARY KEY (bridge_id, key)
|
||||
);
|
||||
|
||||
CREATE TABLE public_media (
|
||||
bridge_id TEXT NOT NULL,
|
||||
public_id TEXT NOT NULL,
|
||||
mxc TEXT NOT NULL,
|
||||
keys jsonb,
|
||||
mimetype TEXT,
|
||||
expiry BIGINT,
|
||||
|
||||
PRIMARY KEY (bridge_id, public_id)
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
-- v23 (compatible with v9+): Add event timestamp for disappearing messages
|
||||
ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0;
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
-- v24 (compatible with v9+): Custom URLs for public media
|
||||
CREATE TABLE public_media (
|
||||
bridge_id TEXT NOT NULL,
|
||||
public_id TEXT NOT NULL,
|
||||
mxc TEXT NOT NULL,
|
||||
keys jsonb,
|
||||
mimetype TEXT,
|
||||
expiry BIGINT,
|
||||
|
||||
PRIMARY KEY (bridge_id, public_id)
|
||||
);
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
-- v25 (compatible with v9+): Flag for message request portals
|
||||
ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false;
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents
|
||||
CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room);
|
||||
CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver);
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
-- v27 (compatible with v9+): Add column for extra ghost profile metadata
|
||||
ALTER TABLE ghost ADD COLUMN extra_profile jsonb;
|
||||
|
|
@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin {
|
|||
|
||||
func (u *UserLogin) sqlVariables() []any {
|
||||
var remoteProfile dbutil.JSON
|
||||
if !u.RemoteProfile.IsZero() {
|
||||
if !u.RemoteProfile.IsEmpty() {
|
||||
remoteProfile.Data = &u.RemoteProfile
|
||||
}
|
||||
return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}}
|
||||
|
|
|
|||
|
|
@ -86,8 +86,8 @@ func (dl *DisappearLoop) Stop() {
|
|||
}
|
||||
}
|
||||
|
||||
func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) {
|
||||
startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS)
|
||||
func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) {
|
||||
startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -38,53 +38,35 @@ var ErrNotLoggedIn = errors.New("not logged in")
|
|||
// but direct media is not enabled.
|
||||
var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled")
|
||||
|
||||
var ErrPortalIsDeleted = errors.New("portal is deleted")
|
||||
var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event")
|
||||
|
||||
// Common message status errors
|
||||
var (
|
||||
ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
|
||||
ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
|
||||
ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
|
||||
ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
|
||||
ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
|
||||
ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported)
|
||||
ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
|
||||
ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
|
||||
|
||||
ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
|
||||
ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
|
||||
ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true)
|
||||
|
||||
ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true)
|
||||
ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage()
|
||||
ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false)
|
||||
ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
|
||||
ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
|
||||
ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false)
|
||||
ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
|
||||
ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true)
|
||||
ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
|
||||
ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true)
|
||||
ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
|
||||
ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false)
|
||||
ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
|
||||
ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld)
|
||||
)
|
||||
|
||||
// Common login interface errors
|
||||
|
|
|
|||
|
|
@ -9,15 +9,12 @@ package bridgev2
|
|||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exmime"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/database"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
|
|
@ -137,11 +134,10 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32
|
|||
}
|
||||
|
||||
type UserInfo struct {
|
||||
Identifiers []string
|
||||
Name *string
|
||||
Avatar *Avatar
|
||||
IsBot *bool
|
||||
ExtraProfile database.ExtraProfile
|
||||
Identifiers []string
|
||||
Name *string
|
||||
Avatar *Avatar
|
||||
IsBot *bool
|
||||
|
||||
ExtraUpdates ExtraUpdater[*Ghost]
|
||||
}
|
||||
|
|
@ -189,9 +185,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (ghost *Ghost) getExtraProfileMeta() any {
|
||||
func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra {
|
||||
bridgeName := ghost.Bridge.Network.GetName()
|
||||
baseExtra := &event.BeeperProfileExtra{
|
||||
return &event.BeeperProfileExtra{
|
||||
RemoteID: string(ghost.ID),
|
||||
Identifiers: ghost.Identifiers,
|
||||
Service: bridgeName.BeeperBridgeType,
|
||||
|
|
@ -199,35 +195,23 @@ func (ghost *Ghost) getExtraProfileMeta() any {
|
|||
IsBridgeBot: false,
|
||||
IsNetworkBot: ghost.IsBot,
|
||||
}
|
||||
if len(ghost.ExtraProfile) == 0 {
|
||||
return baseExtra
|
||||
}
|
||||
mergedExtra := maps.Clone(ghost.ExtraProfile)
|
||||
baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra))
|
||||
exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra))
|
||||
return mergedExtra
|
||||
}
|
||||
|
||||
func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool {
|
||||
if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta {
|
||||
ghost.ContactInfoSet = false
|
||||
return false
|
||||
}
|
||||
func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool {
|
||||
if identifiers != nil {
|
||||
slices.Sort(identifiers)
|
||||
}
|
||||
changed := extraProfile.CopyTo(&ghost.ExtraProfile)
|
||||
if ghost.ContactInfoSet &&
|
||||
(identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) &&
|
||||
(isBot == nil || *isBot == ghost.IsBot) {
|
||||
return false
|
||||
}
|
||||
if identifiers != nil {
|
||||
changed = changed || !slices.Equal(identifiers, ghost.Identifiers)
|
||||
ghost.Identifiers = identifiers
|
||||
}
|
||||
if isBot != nil {
|
||||
changed = changed || *isBot != ghost.IsBot
|
||||
ghost.IsBot = *isBot
|
||||
}
|
||||
if ghost.ContactInfoSet && !changed {
|
||||
return false
|
||||
}
|
||||
err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta())
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata")
|
||||
|
|
@ -250,7 +234,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool {
|
|||
}
|
||||
|
||||
func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) {
|
||||
if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
|
||||
if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) {
|
||||
return
|
||||
}
|
||||
info, err := source.Client.GetUserInfo(ctx, ghost)
|
||||
|
|
@ -260,16 +244,12 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin
|
|||
zerolog.Ctx(ctx).Debug().
|
||||
Bool("has_name", ghost.Name != "").
|
||||
Bool("name_set", ghost.NameSet).
|
||||
Bool("has_avatar", ghost.AvatarMXC != "").
|
||||
Bool("avatar_set", ghost.AvatarSet).
|
||||
Msg("Updating ghost info in IfNecessary call")
|
||||
ghost.UpdateInfo(ctx, info)
|
||||
} else {
|
||||
zerolog.Ctx(ctx).Trace().
|
||||
Bool("has_name", ghost.Name != "").
|
||||
Bool("name_set", ghost.NameSet).
|
||||
Bool("has_avatar", ghost.AvatarMXC != "").
|
||||
Bool("avatar_set", ghost.AvatarSet).
|
||||
Msg("No ghost info received in IfNecessary call")
|
||||
}
|
||||
}
|
||||
|
|
@ -297,14 +277,9 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) {
|
|||
}
|
||||
if info.Avatar != nil {
|
||||
update = ghost.UpdateAvatar(ctx, info.Avatar) || update
|
||||
} else if oldAvatar == "" && !ghost.AvatarSet {
|
||||
// Special case: nil avatar means we're not expecting one ever, if we don't currently have
|
||||
// one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary.
|
||||
ghost.AvatarSet = true
|
||||
update = true
|
||||
}
|
||||
if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil {
|
||||
update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update
|
||||
if info.Identifiers != nil || info.IsBot != nil {
|
||||
update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update
|
||||
}
|
||||
if info.ExtraUpdates != nil {
|
||||
update = info.ExtraUpdates(ctx, ghost) || update
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/event"
|
||||
)
|
||||
|
||||
// LoginProcess represents a single occurrence of a user logging into the remote network.
|
||||
|
|
@ -179,8 +178,6 @@ const (
|
|||
LoginInputFieldTypeToken LoginInputFieldType = "token"
|
||||
LoginInputFieldTypeURL LoginInputFieldType = "url"
|
||||
LoginInputFieldTypeDomain LoginInputFieldType = "domain"
|
||||
LoginInputFieldTypeSelect LoginInputFieldType = "select"
|
||||
LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code"
|
||||
)
|
||||
|
||||
type LoginInputDataField struct {
|
||||
|
|
@ -192,13 +189,8 @@ type LoginInputDataField struct {
|
|||
Name string `json:"name"`
|
||||
// The description of the field shown to the user.
|
||||
Description string `json:"description"`
|
||||
// A default value that the client can pre-fill the field with.
|
||||
DefaultValue string `json:"default_value,omitempty"`
|
||||
// A regex pattern that the client can use to validate input client-side.
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
// For fields of type select, the valid options.
|
||||
// Pattern may also be filled with a regex that matches the same options.
|
||||
Options []string `json:"options,omitempty"`
|
||||
// A function that validates the input and optionally cleans it up before it's submitted to the connector.
|
||||
Validate func(string) (string, error) `json:"-"`
|
||||
}
|
||||
|
|
@ -273,23 +265,6 @@ func (f *LoginInputDataField) FillDefaultValidate() {
|
|||
type LoginUserInputParams struct {
|
||||
// The fields that the user needs to fill in.
|
||||
Fields []LoginInputDataField `json:"fields"`
|
||||
|
||||
// Attachments to display alongside the input fields.
|
||||
Attachments []*LoginUserInputAttachment `json:"attachments"`
|
||||
}
|
||||
|
||||
type LoginUserInputAttachment struct {
|
||||
Type event.MessageType `json:"type,omitempty"`
|
||||
FileName string `json:"filename,omitempty"`
|
||||
Content []byte `json:"content,omitempty"`
|
||||
Info LoginUserInputAttachmentInfo `json:"info,omitempty"`
|
||||
}
|
||||
|
||||
type LoginUserInputAttachmentInfo struct {
|
||||
MimeType string `json:"mimetype,omitempty"`
|
||||
Width int `json:"w,omitempty"`
|
||||
Height int `json:"h,omitempty"`
|
||||
Size int `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
type LoginCompleteParams struct {
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ import (
|
|||
_ "go.mau.fi/util/dbutil/litestream"
|
||||
"go.mau.fi/util/exbytes"
|
||||
"go.mau.fi/util/exsync"
|
||||
"go.mau.fi/util/ptr"
|
||||
"go.mau.fi/util/random"
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
|
|
@ -81,8 +80,6 @@ type Connector struct {
|
|||
|
||||
MediaConfig mautrix.RespMediaConfig
|
||||
SpecVersions *mautrix.RespVersions
|
||||
SpecCaps *mautrix.RespCapabilities
|
||||
specCapsLock sync.Mutex
|
||||
Capabilities *bridgev2.MatrixCapabilities
|
||||
IgnoreUnsupportedServer bool
|
||||
|
||||
|
|
@ -144,20 +141,14 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) {
|
|||
br.EventProcessor.On(event.EventReaction, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent)
|
||||
br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent)
|
||||
br.EventProcessor.On(event.StateMember, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.StateTopic, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent)
|
||||
br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent)
|
||||
br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent)
|
||||
br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent)
|
||||
br.Bot = br.AS.BotIntent()
|
||||
br.Crypto = NewCryptoHelper(br)
|
||||
br.Bridge.Commands.(*commands.Processor).AddHandlers(
|
||||
|
|
@ -282,7 +273,7 @@ func (br *Connector) GetPublicAddress() string {
|
|||
if br.Config.AppService.PublicAddress == "https://bridge.example.com" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimRight(br.Config.AppService.PublicAddress, "/")
|
||||
return br.Config.AppService.PublicAddress
|
||||
}
|
||||
|
||||
func (br *Connector) GetRouter() *http.ServeMux {
|
||||
|
|
@ -344,18 +335,16 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) {
|
|||
}
|
||||
|
||||
func (br *Connector) ensureConnection(ctx context.Context) {
|
||||
triedToRegister := false
|
||||
for {
|
||||
versions, err := br.Bot.Versions(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, mautrix.MForbidden) && !triedToRegister {
|
||||
if errors.Is(err, mautrix.MForbidden) {
|
||||
br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying")
|
||||
err = br.Bot.EnsureRegistered(ctx)
|
||||
if err != nil {
|
||||
br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN")
|
||||
os.Exit(16)
|
||||
}
|
||||
triedToRegister = true
|
||||
} else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) {
|
||||
br.logInitialRequestError(err, "/versions request failed with auth error")
|
||||
os.Exit(16)
|
||||
|
|
@ -368,9 +357,6 @@ func (br *Connector) ensureConnection(ctx context.Context) {
|
|||
*br.AS.SpecVersions = *versions
|
||||
br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites)
|
||||
br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending)
|
||||
br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange)
|
||||
br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) ||
|
||||
(br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -415,21 +401,6 @@ func (br *Connector) ensureConnection(ctx context.Context) {
|
|||
br.Bot.EnsureAppserviceConnection(ctx)
|
||||
}
|
||||
|
||||
func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities {
|
||||
br.specCapsLock.Lock()
|
||||
defer br.specCapsLock.Unlock()
|
||||
if br.SpecCaps != nil {
|
||||
return br.SpecCaps
|
||||
}
|
||||
caps, err := br.Bot.Capabilities(ctx)
|
||||
if err != nil {
|
||||
br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver")
|
||||
return nil
|
||||
}
|
||||
br.SpecCaps = caps
|
||||
return caps
|
||||
}
|
||||
|
||||
func (br *Connector) fetchMediaConfig(ctx context.Context) {
|
||||
cfg, err := br.Bot.GetMediaConfig(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -538,8 +509,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2
|
|||
Msg("Failed to send MSS event")
|
||||
}
|
||||
}
|
||||
if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice &&
|
||||
(ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
|
||||
if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) {
|
||||
content := ms.ToNoticeEvent(evt)
|
||||
if editEvent != "" {
|
||||
content.SetEdit(editEvent)
|
||||
|
|
@ -623,28 +593,13 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve
|
|||
}
|
||||
|
||||
func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) {
|
||||
if stateKey == "" {
|
||||
switch eventType {
|
||||
case event.StateCreate:
|
||||
createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID)
|
||||
if err != nil || createEvt != nil {
|
||||
return createEvt, err
|
||||
}
|
||||
case event.StateJoinRules:
|
||||
joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if joinRulesContent != nil {
|
||||
return &event.Event{
|
||||
Type: event.StateJoinRules,
|
||||
RoomID: roomID,
|
||||
StateKey: ptr.Ptr(""),
|
||||
Content: event.Content{Parsed: joinRulesContent},
|
||||
}, nil
|
||||
}
|
||||
if eventType == event.StateCreate && stateKey == "" {
|
||||
createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID)
|
||||
if err != nil || createEvt != nil {
|
||||
return createEvt, err
|
||||
}
|
||||
}
|
||||
return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey)
|
||||
return br.Bot.FullStateEvent(ctx, roomID, eventType, "")
|
||||
}
|
||||
|
||||
func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
|
||||
|
|
@ -687,7 +642,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr
|
|||
if intent != nil {
|
||||
intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp)
|
||||
}
|
||||
if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction {
|
||||
if evt.Type != event.EventEncrypted {
|
||||
err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ import (
|
|||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/database"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
|
@ -38,9 +37,9 @@ func init() {
|
|||
|
||||
var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil)
|
||||
|
||||
var NoSessionFound = crypto.ErrNoSessionFound
|
||||
var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex
|
||||
var UnknownMessageIndex = olm.ErrUnknownMessageIndex
|
||||
var NoSessionFound = crypto.NoSessionFound
|
||||
var DuplicateMessageIndex = crypto.DuplicateMessageIndex
|
||||
var UnknownMessageIndex = olm.UnknownMessageIndex
|
||||
|
||||
type CryptoHelper struct {
|
||||
bridge *Connector
|
||||
|
|
@ -136,19 +135,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
if isExistingDevice {
|
||||
if !helper.verifyKeysAreOnServer(ctx) {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
err = helper.ShareKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to share device keys: %w", err)
|
||||
}
|
||||
}
|
||||
if helper.bridge.Config.Encryption.SelfSign {
|
||||
if !helper.doSelfSign(ctx) {
|
||||
os.Exit(34)
|
||||
}
|
||||
helper.verifyKeysAreOnServer(ctx)
|
||||
}
|
||||
|
||||
go helper.resyncEncryptionInfo(context.TODO())
|
||||
|
|
@ -156,46 +143,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool {
|
||||
log := zerolog.Ctx(ctx)
|
||||
hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx)
|
||||
if err != nil {
|
||||
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status")
|
||||
return false
|
||||
}
|
||||
log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status")
|
||||
keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey)
|
||||
if !hasKeys || keyInDB == "overwrite" {
|
||||
if keyInDB != "" && keyInDB != "overwrite" {
|
||||
log.WithLevel(zerolog.FatalLevel).
|
||||
Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.")
|
||||
return false
|
||||
}
|
||||
recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx)
|
||||
if recoveryKey != "" {
|
||||
helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey)
|
||||
}
|
||||
if err != nil {
|
||||
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign")
|
||||
return false
|
||||
}
|
||||
log.Info().Msg("Generated new recovery key and self-signed bot device")
|
||||
} else if !isVerified {
|
||||
if keyInDB == "" {
|
||||
log.WithLevel(zerolog.FatalLevel).
|
||||
Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.")
|
||||
return false
|
||||
}
|
||||
err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB)
|
||||
if err != nil {
|
||||
log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key")
|
||||
return false
|
||||
}
|
||||
log.Info().Msg("Verified bot device with existing recovery key")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
|
||||
log := helper.log.With().Str("action", "resync encryption event").Logger()
|
||||
rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
|
||||
|
|
@ -210,12 +157,12 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
|
|||
var evt event.EncryptionEventContent
|
||||
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event")
|
||||
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
|
||||
_, err = helper.store.DB.Exec(ctx, `
|
||||
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
|
||||
`, roomID)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync")
|
||||
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync")
|
||||
}
|
||||
} else {
|
||||
maxAge := evt.RotationPeriodMillis
|
||||
|
|
@ -238,9 +185,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
|
|||
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
|
||||
`, maxAge, maxMessages, roomID)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table")
|
||||
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table")
|
||||
} else {
|
||||
log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table")
|
||||
log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -286,7 +233,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
|
|||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to find existing device ID: %w", err)
|
||||
} else if len(deviceID) > 0 {
|
||||
helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database")
|
||||
helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
|
||||
}
|
||||
// Create a new client instance with the default AS settings (including as_token),
|
||||
// the Login call will then override the access token in the client.
|
||||
|
|
@ -327,7 +274,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool
|
|||
return client, deviceID != "", nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool {
|
||||
func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) {
|
||||
helper.log.Debug().Msg("Making sure keys are still on server")
|
||||
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
|
||||
|
|
@ -340,11 +287,10 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool {
|
|||
}
|
||||
device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID]
|
||||
if ok && len(device.Keys) > 0 {
|
||||
return true
|
||||
return
|
||||
}
|
||||
helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto")
|
||||
helper.Reset(ctx, false)
|
||||
return false
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Start() {
|
||||
|
|
@ -439,7 +385,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy
|
|||
var encrypted *event.EncryptedEventContent
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
|
||||
if err != nil {
|
||||
if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) {
|
||||
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().Err(err).
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ package matrix
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
|
@ -28,7 +27,6 @@ import (
|
|||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
|
||||
"maunium.net/go/mautrix/crypto/attachment"
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/pushrules"
|
||||
|
|
@ -45,13 +43,13 @@ type ASIntent struct {
|
|||
|
||||
var _ bridgev2.MatrixAPI = (*ASIntent)(nil)
|
||||
var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil)
|
||||
var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil)
|
||||
|
||||
func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) {
|
||||
if extra == nil {
|
||||
extra = &bridgev2.MatrixSendExtra{}
|
||||
}
|
||||
if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) {
|
||||
// TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions
|
||||
if eventType == event.EventRedaction {
|
||||
parsedContent := content.Parsed.(*event.RedactionEventContent)
|
||||
as.Matrix.AddDoublePuppetValue(content)
|
||||
return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{
|
||||
|
|
@ -59,7 +57,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
|
|||
Extra: content.Raw,
|
||||
})
|
||||
}
|
||||
if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction {
|
||||
if eventType != event.EventReaction && eventType != event.EventRedaction {
|
||||
msgContent, ok := content.Parsed.(*event.MessageEventContent)
|
||||
if ok {
|
||||
msgContent.AddPerMessageProfileFallback()
|
||||
|
|
@ -84,27 +82,16 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType
|
|||
eventType = event.EventEncrypted
|
||||
}
|
||||
}
|
||||
return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()})
|
||||
}
|
||||
|
||||
func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) {
|
||||
if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) {
|
||||
return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support")
|
||||
if extra.Timestamp.IsZero() {
|
||||
return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content)
|
||||
} else {
|
||||
return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli())
|
||||
}
|
||||
if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil {
|
||||
return nil, fmt.Errorf("failed to check if room is encrypted: %w", err)
|
||||
} else if encrypted && as.Connector.Crypto != nil {
|
||||
if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventType = event.EventEncrypted
|
||||
}
|
||||
return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID})
|
||||
}
|
||||
|
||||
func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) {
|
||||
targetContent, ok := content.Parsed.(*event.MemberEventContent)
|
||||
if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" {
|
||||
targetContent := content.Parsed.(*event.MemberEventContent)
|
||||
if targetContent.Displayname != "" || targetContent.AvatarURL != "" {
|
||||
return
|
||||
}
|
||||
memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID)
|
||||
|
|
@ -139,7 +126,11 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e
|
|||
if eventType == event.StateMember {
|
||||
as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content)
|
||||
}
|
||||
resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()})
|
||||
if ts.IsZero() {
|
||||
resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content)
|
||||
} else {
|
||||
resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli())
|
||||
}
|
||||
if err != nil && eventType == event.StateMember {
|
||||
var httpErr mautrix.HTTPError
|
||||
if errors.As(err, &httpErr) && httpErr.RespError != nil &&
|
||||
|
|
@ -421,7 +412,6 @@ func (as *ASIntent) UploadMediaStream(
|
|||
removeAndClose(replFile)
|
||||
removeAndClose(tempFile)
|
||||
}
|
||||
req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
|
||||
startedAsyncUpload = true
|
||||
var resp *mautrix.RespCreateMXC
|
||||
resp, err = as.Matrix.UploadAsync(ctx, req)
|
||||
|
|
@ -454,7 +444,6 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn
|
|||
as.Connector.uploadSema.Release(int64(len(req.ContentBytes)))
|
||||
}
|
||||
}
|
||||
req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx)
|
||||
var resp *mautrix.RespCreateMXC
|
||||
resp, err = as.Matrix.UploadAsync(ctx, req)
|
||||
if resp != nil {
|
||||
|
|
@ -486,62 +475,11 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr
|
|||
return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL)
|
||||
}
|
||||
|
||||
func dataToFields(data any) (map[string]json.RawMessage, error) {
|
||||
fields, ok := data.(map[string]json.RawMessage)
|
||||
if ok {
|
||||
return fields, nil
|
||||
}
|
||||
d, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d = canonicaljson.CanonicalJSONAssumeValid(d)
|
||||
err = json.Unmarshal(d, &fields)
|
||||
return fields, err
|
||||
}
|
||||
|
||||
func marshalField(val any) json.RawMessage {
|
||||
data, _ := json.Marshal(val)
|
||||
if len(data) > 0 && (data[0] == '{' || data[0] == '[') {
|
||||
return canonicaljson.CanonicalJSONAssumeValid(data)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
var nullJSON = json.RawMessage("null")
|
||||
|
||||
func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error {
|
||||
if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
|
||||
return as.Matrix.BeeperUpdateProfile(ctx, data)
|
||||
} else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo {
|
||||
fields, err := dataToFields(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal fields: %w", err)
|
||||
}
|
||||
currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current profile: %w", err)
|
||||
}
|
||||
for key, val := range fields {
|
||||
existing, ok := currentProfile.Extra[key]
|
||||
if !ok {
|
||||
if bytes.Equal(val, nullJSON) {
|
||||
continue
|
||||
}
|
||||
err = as.Matrix.SetProfileField(ctx, key, val)
|
||||
} else if !bytes.Equal(marshalField(existing), val) {
|
||||
if bytes.Equal(val, nullJSON) {
|
||||
err = as.Matrix.DeleteProfileField(ctx, key)
|
||||
} else {
|
||||
err = as.Matrix.SetProfileField(ctx, key, val)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set profile field %q: %w", key, err)
|
||||
}
|
||||
}
|
||||
if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return as.Matrix.BeeperUpdateProfile(ctx, data)
|
||||
}
|
||||
|
||||
func (as *ASIntent) GetMXID() id.UserID {
|
||||
|
|
@ -583,39 +521,6 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent {
|
|||
return content
|
||||
}
|
||||
|
||||
func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) {
|
||||
if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
|
||||
// Hungryserv doesn't override the capabilities endpoint nor do room versions
|
||||
return
|
||||
}
|
||||
caps := as.Connector.fetchCapabilities(ctx)
|
||||
roomVer := req.RoomVersion
|
||||
if roomVer == "" && caps != nil && caps.RoomVersions != nil {
|
||||
roomVer = id.RoomVersion(caps.RoomVersions.Default)
|
||||
}
|
||||
if roomVer != "" && !roomVer.PrivilegedRoomCreators() {
|
||||
return
|
||||
}
|
||||
creators, _ := req.CreationContent["additional_creators"].([]id.UserID)
|
||||
creators = append(slices.Clone(creators), as.GetMXID())
|
||||
if req.PowerLevelOverride != nil {
|
||||
for _, creator := range creators {
|
||||
delete(req.PowerLevelOverride.Users, creator)
|
||||
}
|
||||
}
|
||||
for _, evt := range req.InitialState {
|
||||
if evt.Type != event.StatePowerLevels {
|
||||
continue
|
||||
}
|
||||
content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent)
|
||||
if ok {
|
||||
for _, creator := range creators {
|
||||
delete(content.Users, creator)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) {
|
||||
if as.Connector.Config.Encryption.Default {
|
||||
req.InitialState = append(req.InitialState, &event.Event{
|
||||
|
|
@ -631,7 +536,6 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom)
|
|||
}
|
||||
req.CreationContent["m.federate"] = false
|
||||
}
|
||||
as.filterCreateRequestForV12(ctx, req)
|
||||
resp, err := as.Matrix.CreateRoom(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -673,9 +577,6 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id.
|
|||
}
|
||||
|
||||
func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error {
|
||||
if roomID == "" {
|
||||
return nil
|
||||
}
|
||||
if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) {
|
||||
err := as.Matrix.BeeperDeleteRoom(ctx, roomID)
|
||||
if err != nil {
|
||||
|
|
@ -773,23 +674,3 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) {
|
||||
evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content")
|
||||
}
|
||||
|
||||
if evt.Type == event.EventEncrypted {
|
||||
if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt {
|
||||
return nil, errors.New("can't decrypt the event")
|
||||
}
|
||||
return as.Connector.Crypto.Decrypt(ctx, evt)
|
||||
}
|
||||
|
||||
return evt, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,11 +27,6 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) {
|
|||
if br.shouldIgnoreEvent(evt) {
|
||||
return
|
||||
}
|
||||
if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember {
|
||||
zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events")
|
||||
br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
|
||||
return
|
||||
}
|
||||
if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require {
|
||||
zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required")
|
||||
br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true)
|
||||
|
|
@ -68,10 +63,6 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event)
|
|||
case event.EphemeralEventTyping:
|
||||
typingContent := evt.Content.AsTyping()
|
||||
typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser)
|
||||
case event.BeeperEphemeralEventAIStream:
|
||||
if br.shouldIgnoreEvent(evt) {
|
||||
return
|
||||
}
|
||||
}
|
||||
br.Bridge.QueueMatrixEvent(ctx, evt)
|
||||
}
|
||||
|
|
@ -85,11 +76,6 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
|
|||
Str("event_id", evt.ID.String()).
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Logger()
|
||||
if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents {
|
||||
log.Debug().Msg("Dropping event from user with no permission to send events")
|
||||
br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt))
|
||||
return
|
||||
}
|
||||
ctx = log.WithContext(ctx)
|
||||
if br.Crypto == nil {
|
||||
br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true)
|
||||
|
|
@ -101,18 +87,17 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
|
|||
decryptionStart := time.Now()
|
||||
decrypted, err := br.Crypto.Decrypt(ctx, evt)
|
||||
decryptionRetryCount := 0
|
||||
var errorEventID id.EventID
|
||||
if errors.Is(err, NoSessionFound) {
|
||||
decryptionRetryCount = 1
|
||||
log.Debug().
|
||||
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, waiting for keys to arrive...")
|
||||
go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false)
|
||||
go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false)
|
||||
if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
|
||||
decrypted, err = br.Crypto.Decrypt(ctx, evt)
|
||||
} else {
|
||||
go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID)
|
||||
go br.waitLongerForSession(ctx, evt, decryptionStart)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -121,18 +106,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event)
|
|||
go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true)
|
||||
return
|
||||
}
|
||||
br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart))
|
||||
br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart))
|
||||
}
|
||||
|
||||
func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) {
|
||||
func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
content := evt.Content.AsEncrypted()
|
||||
log.Debug().
|
||||
Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, requesting keys and waiting longer...")
|
||||
|
||||
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
|
||||
go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
var errorEventID *id.EventID
|
||||
go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false)
|
||||
|
||||
if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
|
|
@ -235,6 +220,7 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event
|
|||
go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount)
|
||||
decrypted.Mautrix.CheckpointSent = true
|
||||
decrypted.Mautrix.DecryptionDuration = duration
|
||||
decrypted.Mautrix.EventSource |= event.SourceDecrypted
|
||||
br.EventProcessor.Dispatch(ctx, decrypted)
|
||||
if errorEventID != nil && *errorEventID != "" {
|
||||
_, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID)
|
||||
|
|
|
|||
|
|
@ -66,12 +66,7 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s
|
|||
} else if errors.Is(err, dbutil.ErrForeignTables) {
|
||||
br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info")
|
||||
} else if errors.Is(err, dbutil.ErrNotOwned) {
|
||||
var noe dbutil.NotOwnedError
|
||||
if errors.As(err, &noe) && noe.Owner == br.Name {
|
||||
br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?")
|
||||
} else {
|
||||
br.Log.Info().Msg("Sharing the same database with different programs is not supported")
|
||||
}
|
||||
br.Log.Info().Msg("Sharing the same database with different programs is not supported")
|
||||
} else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) {
|
||||
br.Log.Info().Msg("Downgrading the bridge is not supported")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,161 +0,0 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mxmain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/util/random"
|
||||
)
|
||||
|
||||
var randomParseFilePrefix = random.String(16) + "READFILE:"
|
||||
|
||||
func parseEnv(prefix string) iter.Seq2[[]string, string] {
|
||||
return func(yield func([]string, string) bool) {
|
||||
for _, s := range os.Environ() {
|
||||
if !strings.HasPrefix(s, prefix) {
|
||||
continue
|
||||
}
|
||||
kv := strings.SplitN(s, "=", 2)
|
||||
key := strings.TrimPrefix(kv[0], prefix)
|
||||
value := kv[1]
|
||||
if strings.HasSuffix(key, "_FILE") {
|
||||
key = strings.TrimSuffix(key, "_FILE")
|
||||
value = randomParseFilePrefix + value
|
||||
}
|
||||
key = strings.ToLower(key)
|
||||
if !strings.ContainsRune(key, '.') {
|
||||
key = strings.ReplaceAll(key, "__", ".")
|
||||
}
|
||||
if !yield(strings.Split(key, "."), value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func reflectYAMLFieldName(f *reflect.StructField) string {
|
||||
parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2)
|
||||
fieldName := parts[0]
|
||||
if fieldName == "-" && len(parts) == 1 {
|
||||
return ""
|
||||
}
|
||||
if fieldName == "" {
|
||||
return strings.ToLower(f.Name)
|
||||
}
|
||||
return fieldName
|
||||
}
|
||||
|
||||
type reflectGetResult struct {
|
||||
val reflect.Value
|
||||
valKind reflect.Kind
|
||||
remainingPath []string
|
||||
}
|
||||
|
||||
func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) {
|
||||
if len(path) == 0 {
|
||||
return &reflectGetResult{val: rv, valKind: rv.Kind()}, true
|
||||
}
|
||||
if rv.Kind() == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
switch rv.Kind() {
|
||||
case reflect.Map:
|
||||
return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true
|
||||
case reflect.Struct:
|
||||
fields := reflect.VisibleFields(rv.Type())
|
||||
for _, field := range fields {
|
||||
fieldName := reflectYAMLFieldName(&field)
|
||||
if fieldName != "" && fieldName == path[0] {
|
||||
return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:])
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) {
|
||||
if len(path) > 0 && path[0] == "network" {
|
||||
return reflectGetYAML(network, path[1:])
|
||||
}
|
||||
return reflectGetYAML(main, path)
|
||||
}
|
||||
|
||||
func formatKeyString(key []string) string {
|
||||
return strings.Join(key, "->")
|
||||
}
|
||||
|
||||
func UpdateConfigFromEnv(cfg, networkData any, prefix string) error {
|
||||
cfgVal := reflect.ValueOf(cfg)
|
||||
networkVal := reflect.ValueOf(networkData)
|
||||
for key, value := range parseEnv(prefix) {
|
||||
field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key)
|
||||
if !ok {
|
||||
return fmt.Errorf("%s not found", formatKeyString(key))
|
||||
}
|
||||
if strings.HasPrefix(value, randomParseFilePrefix) {
|
||||
filepath := strings.TrimPrefix(value, randomParseFilePrefix)
|
||||
fileData, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err)
|
||||
}
|
||||
value = strings.TrimSpace(string(fileData))
|
||||
}
|
||||
var parsedVal any
|
||||
var err error
|
||||
switch field.valKind {
|
||||
case reflect.String:
|
||||
parsedVal = value
|
||||
case reflect.Bool:
|
||||
parsedVal, err = strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
parsedVal, err = strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
parsedVal, err = strconv.ParseUint(value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
parsedVal, err = strconv.ParseFloat(value, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key))
|
||||
}
|
||||
if field.val.Kind() == reflect.Ptr {
|
||||
if field.val.IsNil() {
|
||||
field.val.Set(reflect.New(field.val.Type().Elem()))
|
||||
}
|
||||
field.val = field.val.Elem()
|
||||
}
|
||||
if field.val.Kind() == reflect.Map {
|
||||
key = key[:len(key)-len(field.remainingPath)]
|
||||
mapKeyStr := strings.Join(field.remainingPath, ".")
|
||||
key = append(key, mapKeyStr)
|
||||
if field.val.Type().Key().Kind() != reflect.String {
|
||||
return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key))
|
||||
}
|
||||
field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal))
|
||||
} else {
|
||||
field.val.Set(reflect.ValueOf(parsedVal))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -15,7 +15,6 @@ bridge:
|
|||
# By default, users who are in the same group on the remote network will be
|
||||
# in the same Matrix room bridged to that group. If this is set to true,
|
||||
# every user will get their own Matrix room instead.
|
||||
# SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST.
|
||||
split_portals: false
|
||||
# Should the bridge resend `m.bridge` events to all portals on startup?
|
||||
resend_bridge_info: false
|
||||
|
|
@ -29,9 +28,6 @@ bridge:
|
|||
# How long after an unknown error should the bridge attempt a full reconnect?
|
||||
# Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value.
|
||||
unknown_error_auto_reconnect: null
|
||||
# Maximum number of times to do the auto-reconnect above.
|
||||
# The counter is per login, but is never reset except on logout and restart.
|
||||
unknown_error_max_auto_reconnects: 10
|
||||
|
||||
# Should leaving Matrix rooms be bridged as leaving groups on the remote network?
|
||||
bridge_matrix_leave: false
|
||||
|
|
@ -50,11 +46,6 @@ bridge:
|
|||
# Should cross-room reply metadata be bridged?
|
||||
# Most Matrix clients don't support this and servers may reject such messages too.
|
||||
cross_room_replies: false
|
||||
# If a state event fails to bridge, should the bridge revert any state changes made by that event?
|
||||
revert_failed_state_changes: false
|
||||
# In portals with no relay set, should Matrix users be kicked if they're
|
||||
# not logged into an account that's in the remote chat?
|
||||
kick_matrix_users: true
|
||||
|
||||
# What should be done to portal rooms when a user logs out or is logged out?
|
||||
# Permitted values:
|
||||
|
|
@ -244,9 +235,6 @@ matrix:
|
|||
# The threshold as bytes after which the bridge should roundtrip uploads via the disk
|
||||
# rather than keeping the whole file in memory.
|
||||
upload_file_threshold: 5242880
|
||||
# Should the bridge set additional custom profile info for ghosts?
|
||||
# This can make a lot of requests, as there's no batch profile update endpoint.
|
||||
ghost_extra_profile_info: false
|
||||
|
||||
# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors.
|
||||
analytics:
|
||||
|
|
@ -259,8 +247,10 @@ analytics:
|
|||
|
||||
# Settings for provisioning API
|
||||
provisioning:
|
||||
# Prefix for the provisioning API paths.
|
||||
prefix: /_matrix/provision
|
||||
# Shared secret for authentication. If set to "generate" or null, a random secret will be generated,
|
||||
# or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters.
|
||||
# or if set to "disable", the provisioning API will be disabled.
|
||||
shared_secret: generate
|
||||
# Whether to allow provisioning API requests to be authed using Matrix access tokens.
|
||||
# This follows the same rules as double puppeting to determine which server to contact to check the token,
|
||||
|
|
@ -286,14 +276,6 @@ public_media:
|
|||
expiry: 0
|
||||
# Length of hash to use for public media URLs. Must be between 0 and 32.
|
||||
hash_length: 32
|
||||
# The path prefix for generated URLs. Note that this will NOT change the path where media is actually served.
|
||||
# If you change this, you must configure your reverse proxy to rewrite the path accordingly.
|
||||
path_prefix: /_mautrix/publicmedia
|
||||
# Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs?
|
||||
# If false, the generated URLs will just have the MXC URI and a HMAC signature.
|
||||
# The hash_length field will be used to decide the length of the generated URL.
|
||||
# This also allows invalidating URLs by deleting the database entry.
|
||||
use_database: false
|
||||
|
||||
# Settings for converting remote media to custom mxc:// URIs instead of reuploading.
|
||||
# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html
|
||||
|
|
@ -384,12 +366,6 @@ encryption:
|
|||
# Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861).
|
||||
# Changing this option requires updating the appservice registration file.
|
||||
msc4190: false
|
||||
# Whether to encrypt reactions and reply metadata as per MSC4392.
|
||||
msc4392: false
|
||||
# Should the bridge bot generate a recovery key and cross-signing keys and verify itself?
|
||||
# Note that without the latest version of MSC4190, this will fail if you reset the bridge database.
|
||||
# The generated recovery key will be saved in the kv_store table under `recovery_key`.
|
||||
self_sign: false
|
||||
# Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled.
|
||||
# You must use a client that supports requesting keys from other users to use this feature.
|
||||
allow_key_sharing: true
|
||||
|
|
@ -452,16 +428,6 @@ encryption:
|
|||
# You should not enable this option unless you understand all the implications.
|
||||
disable_device_change_key_rotation: false
|
||||
|
||||
# Prefix for environment variables. All variables with this prefix must map to valid config fields.
|
||||
# Nesting in variable names is represented with a dot (.).
|
||||
# If there are no dots in the name, two underscores (__) are replaced with a dot.
|
||||
#
|
||||
# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token.
|
||||
# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily.
|
||||
#
|
||||
# If this is null, reading config fields from environment will be disabled.
|
||||
env_config_prefix: null
|
||||
|
||||
# Logging config. See https://github.com/tulir/zeroconfig for details.
|
||||
logging:
|
||||
min_level: debug
|
||||
|
|
|
|||
|
|
@ -135,10 +135,7 @@ func (br *BridgeMain) CheckLegacyDB(
|
|||
}
|
||||
var dbVersion int
|
||||
err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to get database version")
|
||||
return
|
||||
} else if dbVersion < expectedVersion {
|
||||
if dbVersion < expectedVersion {
|
||||
log.Fatal().
|
||||
Int("expected_version", expectedVersion).
|
||||
Int("version", dbVersion).
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ import (
|
|||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exzerolog"
|
||||
"go.mau.fi/util/progver"
|
||||
"gopkg.in/yaml.v3"
|
||||
flag "maunium.net/go/mauflag"
|
||||
|
||||
|
|
@ -63,9 +62,6 @@ type BridgeMain struct {
|
|||
// git tag to see if the built version is the release or a dev build.
|
||||
// You can either bump this right after a release or right before, as long as it matches on the release commit.
|
||||
Version string
|
||||
// SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning,
|
||||
// such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch.
|
||||
SemCalVer bool
|
||||
|
||||
// PostInit is a function that will be called after the bridge has been initialized but before it is started.
|
||||
PostInit func()
|
||||
|
|
@ -90,7 +86,11 @@ type BridgeMain struct {
|
|||
RegistrationPath string
|
||||
SaveConfig bool
|
||||
|
||||
ver progver.ProgramVersion
|
||||
baseVersion string
|
||||
commit string
|
||||
LinkifiedVersion string
|
||||
VersionDesc string
|
||||
BuildTime time.Time
|
||||
|
||||
AdditionalShortFlags string
|
||||
AdditionalLongFlags string
|
||||
|
|
@ -99,7 +99,14 @@ type BridgeMain struct {
|
|||
}
|
||||
|
||||
type VersionJSONOutput struct {
|
||||
progver.ProgramVersion
|
||||
Name string
|
||||
URL string
|
||||
|
||||
Version string
|
||||
IsRelease bool
|
||||
Commit string
|
||||
FormattedVersion string
|
||||
BuildTime time.Time
|
||||
|
||||
OS string
|
||||
Arch string
|
||||
|
|
@ -140,11 +147,18 @@ func (br *BridgeMain) PreInit() {
|
|||
flag.PrintHelp()
|
||||
os.Exit(0)
|
||||
} else if *version {
|
||||
fmt.Println(br.ver.VersionDescription)
|
||||
fmt.Println(br.VersionDesc)
|
||||
os.Exit(0)
|
||||
} else if *versionJSON {
|
||||
output := VersionJSONOutput{
|
||||
ProgramVersion: br.ver,
|
||||
URL: br.URL,
|
||||
Name: br.Name,
|
||||
|
||||
Version: br.baseVersion,
|
||||
IsRelease: br.Version == br.baseVersion,
|
||||
Commit: br.commit,
|
||||
FormattedVersion: br.Version,
|
||||
BuildTime: br.BuildTime,
|
||||
|
||||
OS: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
|
|
@ -226,8 +240,8 @@ func (br *BridgeMain) Init() {
|
|||
|
||||
br.Log.Info().
|
||||
Str("name", br.Name).
|
||||
Str("version", br.ver.FormattedVersion).
|
||||
Time("built_at", br.ver.BuildTime).
|
||||
Str("version", br.Version).
|
||||
Time("built_at", br.BuildTime).
|
||||
Str("go_version", runtime.Version()).
|
||||
Msg("Initializing bridge")
|
||||
|
||||
|
|
@ -241,7 +255,7 @@ func (br *BridgeMain) Init() {
|
|||
br.Matrix.AS.DoublePuppetValue = br.Name
|
||||
br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{
|
||||
Func: func(ce *commands.Event) {
|
||||
ce.Reply(br.ver.MarkdownDescription())
|
||||
ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123))
|
||||
},
|
||||
Name: "version",
|
||||
Help: commands.HelpMeta{
|
||||
|
|
@ -354,13 +368,6 @@ func (br *BridgeMain) LoadConfig() {
|
|||
}
|
||||
}
|
||||
cfg.Bridge.Backfill = cfg.Backfill
|
||||
if cfg.EnvConfigPrefix != "" {
|
||||
err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err)
|
||||
os.Exit(10)
|
||||
}
|
||||
}
|
||||
br.Config = &cfg
|
||||
}
|
||||
|
||||
|
|
@ -439,12 +446,42 @@ func (br *BridgeMain) Stop() {
|
|||
//
|
||||
// (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`)
|
||||
func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) {
|
||||
br.ver = progver.ProgramVersion{
|
||||
Name: br.Name,
|
||||
URL: br.URL,
|
||||
BaseVersion: br.Version,
|
||||
SemCalVer: br.SemCalVer,
|
||||
}.Init(tag, commit, rawBuildTime)
|
||||
mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent)
|
||||
br.Version = br.ver.FormattedVersion
|
||||
br.baseVersion = br.Version
|
||||
if len(tag) > 0 && tag[0] == 'v' {
|
||||
tag = tag[1:]
|
||||
}
|
||||
if tag != br.Version {
|
||||
suffix := ""
|
||||
if !strings.HasSuffix(br.Version, "+dev") {
|
||||
suffix = "+dev"
|
||||
}
|
||||
if len(commit) > 8 {
|
||||
br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8])
|
||||
} else {
|
||||
br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix)
|
||||
}
|
||||
}
|
||||
|
||||
br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version)
|
||||
if tag == br.Version {
|
||||
br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag)
|
||||
} else if len(commit) > 8 {
|
||||
br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1)
|
||||
}
|
||||
var buildTime time.Time
|
||||
if rawBuildTime != "unknown" {
|
||||
buildTime, _ = time.Parse(time.RFC3339, rawBuildTime)
|
||||
}
|
||||
var builtWith string
|
||||
if buildTime.IsZero() {
|
||||
rawBuildTime = "unknown"
|
||||
builtWith = runtime.Version()
|
||||
} else {
|
||||
rawBuildTime = buildTime.Format(time.RFC1123)
|
||||
builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version())
|
||||
}
|
||||
mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent)
|
||||
br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith)
|
||||
br.commit = commit
|
||||
br.BuildTime = buildTime
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ import (
|
|||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/bridgev2/provisionutil"
|
||||
"maunium.net/go/mautrix/bridgev2/status"
|
||||
"maunium.net/go/mautrix/federation"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
|
@ -85,9 +84,10 @@ const (
|
|||
provisioningUserKey provisioningContextKey = iota
|
||||
provisioningUserLoginKey
|
||||
provisioningLoginProcessKey
|
||||
ProvisioningKeyRequest
|
||||
)
|
||||
|
||||
const ProvisioningKeyRequest = "fi.mau.provision.request"
|
||||
|
||||
func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
|
||||
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
|
||||
}
|
||||
|
|
@ -96,7 +96,12 @@ func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
|
|||
return prov.Router
|
||||
}
|
||||
|
||||
func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI {
|
||||
type IProvisioningAPI interface {
|
||||
GetRouter() *http.ServeMux
|
||||
GetUser(r *http.Request) *bridgev2.User
|
||||
}
|
||||
|
||||
func (br *Connector) GetProvisioning() IProvisioningAPI {
|
||||
return br.Provisioning
|
||||
}
|
||||
|
||||
|
|
@ -114,7 +119,6 @@ func (prov *ProvisioningAPI) Init() {
|
|||
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
|
||||
prov.Router = http.NewServeMux()
|
||||
prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami)
|
||||
prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities)
|
||||
prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows)
|
||||
prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
|
||||
prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep)
|
||||
|
|
@ -124,7 +128,7 @@ func (prov *ProvisioningAPI) Init() {
|
|||
prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
|
||||
prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
|
||||
prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
|
||||
prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup)
|
||||
prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)
|
||||
|
||||
if prov.br.Config.Provisioning.EnableSessionTransfers {
|
||||
prov.log.Debug().Msg("Enabling session transfer API")
|
||||
|
|
@ -206,20 +210,12 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI
|
|||
}
|
||||
}
|
||||
|
||||
func disabledAuth(w http.ResponseWriter, r *http.Request) {
|
||||
mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w)
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
|
||||
secret := prov.br.Config.Provisioning.SharedSecret
|
||||
if len(secret) < 16 {
|
||||
return http.HandlerFunc(disabledAuth)
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
if auth == "" {
|
||||
mautrix.MMissingToken.WithMessage("Missing auth token").Write(w)
|
||||
} else if !exstrings.ConstantTimeEqual(auth, secret) {
|
||||
} else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
|
||||
mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w)
|
||||
} else {
|
||||
h.ServeHTTP(w, r)
|
||||
|
|
@ -228,10 +224,6 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
||||
secret := prov.br.Config.Provisioning.SharedSecret
|
||||
if len(secret) < 16 {
|
||||
return http.HandlerFunc(disabledAuth)
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
if auth == "" && prov.GetAuthFromRequest != nil {
|
||||
|
|
@ -245,7 +237,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
|
|||
if userID == "" && prov.GetUserIDFromRequest != nil {
|
||||
userID = prov.GetUserIDFromRequest(r)
|
||||
}
|
||||
if !exstrings.ConstantTimeEqual(auth, secret) {
|
||||
if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) {
|
||||
var err error
|
||||
if strings.HasPrefix(auth, "openid:") {
|
||||
err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:"))
|
||||
|
|
@ -324,7 +316,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) {
|
|||
prevState.UserID = ""
|
||||
prevState.RemoteID = ""
|
||||
prevState.RemoteName = ""
|
||||
prevState.RemoteProfile = status.RemoteProfile{}
|
||||
prevState.RemoteProfile = nil
|
||||
resp.Logins[i] = RespWhoamiLogin{
|
||||
StateEvent: prevState.StateEvent,
|
||||
StateTS: prevState.Timestamp,
|
||||
|
|
@ -356,24 +348,18 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques
|
|||
})
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) {
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning)
|
||||
}
|
||||
|
||||
var ErrNilStep = errors.New("bridge returned nil step with no error")
|
||||
var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"}
|
||||
|
||||
func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) {
|
||||
overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r)
|
||||
if failed {
|
||||
return
|
||||
}
|
||||
user := prov.GetUser(r)
|
||||
if overrideLogin == nil && user.HasTooManyLogins() {
|
||||
ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w)
|
||||
return
|
||||
}
|
||||
login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID"))
|
||||
login, err := prov.net.CreateLogin(
|
||||
r.Context(),
|
||||
prov.GetUser(r),
|
||||
r.PathValue("flowID"),
|
||||
)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
|
||||
RespondWithError(w, err, "Internal error creating login process")
|
||||
|
|
@ -403,18 +389,10 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
|
|||
Override: overrideLogin,
|
||||
}
|
||||
prov.loginsLock.Unlock()
|
||||
zerolog.Ctx(r.Context()).Info().
|
||||
Any("first_step", firstStep).
|
||||
Msg("Created login process")
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep})
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) {
|
||||
zerolog.Ctx(ctx).Info().
|
||||
Str("step_id", step.StepID).
|
||||
Str("user_login_id", string(step.CompleteParams.UserLoginID)).
|
||||
Msg("Login completed successfully")
|
||||
prov.deleteLogin(login, false)
|
||||
if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID {
|
||||
return
|
||||
}
|
||||
|
|
@ -428,15 +406,6 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
|
|||
}, bridgev2.DeleteOpts{LogoutRemote: true})
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) {
|
||||
if cancel {
|
||||
login.Process.Cancel()
|
||||
}
|
||||
prov.loginsLock.Lock()
|
||||
delete(prov.logins, login.ID)
|
||||
prov.loginsLock.Unlock()
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) {
|
||||
loginID := r.PathValue("loginProcessID")
|
||||
prov.loginsLock.RLock()
|
||||
|
|
@ -507,14 +476,11 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http
|
|||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input")
|
||||
RespondWithError(w, err, "Internal error submitting input")
|
||||
prov.deleteLogin(login, true)
|
||||
return
|
||||
}
|
||||
login.NextStep = nextStep
|
||||
if nextStep.Type == bridgev2.LoginStepTypeComplete {
|
||||
prov.handleCompleteStep(r.Context(), login, nextStep)
|
||||
} else {
|
||||
zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
|
||||
}
|
||||
|
|
@ -528,14 +494,11 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
|
|||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait")
|
||||
RespondWithError(w, err, "Internal error waiting for login")
|
||||
prov.deleteLogin(login, true)
|
||||
return
|
||||
}
|
||||
login.NextStep = nextStep
|
||||
if nextStep.Type == bridgev2.LoginStepTypeComplete {
|
||||
prov.handleCompleteStep(r.Context(), login, nextStep)
|
||||
} else {
|
||||
zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step")
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep})
|
||||
}
|
||||
|
|
@ -619,23 +582,115 @@ func RespondWithError(w http.ResponseWriter, err error, message string) {
|
|||
}
|
||||
}
|
||||
|
||||
type RespResolveIdentifier struct {
|
||||
ID networkid.UserID `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
|
||||
Identifiers []string `json:"identifiers,omitempty"`
|
||||
MXID id.UserID `json:"mxid,omitempty"`
|
||||
DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) {
|
||||
login := prov.GetLoginForRequest(w, r)
|
||||
if login == nil {
|
||||
return
|
||||
}
|
||||
resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat)
|
||||
api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
|
||||
if !ok {
|
||||
mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w)
|
||||
return
|
||||
}
|
||||
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
|
||||
RespondWithError(w, err, "Internal error resolving identifier")
|
||||
return
|
||||
} else if resp == nil {
|
||||
mautrix.MNotFound.WithMessage("Identifier not found").Write(w)
|
||||
} else {
|
||||
status := http.StatusOK
|
||||
if resp.JustCreated {
|
||||
status = http.StatusCreated
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, status, resp)
|
||||
return
|
||||
}
|
||||
apiResp := &RespResolveIdentifier{
|
||||
ID: resp.UserID,
|
||||
}
|
||||
status := http.StatusOK
|
||||
if resp.Ghost != nil {
|
||||
if resp.UserInfo != nil {
|
||||
resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo)
|
||||
}
|
||||
apiResp.Name = resp.Ghost.Name
|
||||
apiResp.AvatarURL = resp.Ghost.AvatarMXC
|
||||
apiResp.Identifiers = resp.Ghost.Identifiers
|
||||
apiResp.MXID = resp.Ghost.Intent.GetMXID()
|
||||
} else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
|
||||
apiResp.Name = *resp.UserInfo.Name
|
||||
}
|
||||
if resp.Chat != nil {
|
||||
if resp.Chat.Portal == nil {
|
||||
resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal")
|
||||
mautrix.MUnknown.WithMessage("Failed to get portal").Write(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
if createChat && resp.Chat.Portal.MXID == "" {
|
||||
status = http.StatusCreated
|
||||
err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room")
|
||||
mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
apiResp.DMRoomID = resp.Chat.Portal.MXID
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, status, apiResp)
|
||||
}
|
||||
|
||||
type RespGetContactList struct {
|
||||
Contacts []*RespResolveIdentifier `json:"contacts"`
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) {
|
||||
apiResp = make([]*RespResolveIdentifier, len(resp))
|
||||
for i, contact := range resp {
|
||||
apiContact := &RespResolveIdentifier{
|
||||
ID: contact.UserID,
|
||||
}
|
||||
apiResp[i] = apiContact
|
||||
if contact.UserInfo != nil {
|
||||
if contact.UserInfo.Name != nil {
|
||||
apiContact.Name = *contact.UserInfo.Name
|
||||
}
|
||||
if contact.UserInfo.Identifiers != nil {
|
||||
apiContact.Identifiers = contact.UserInfo.Identifiers
|
||||
}
|
||||
}
|
||||
if contact.Ghost != nil {
|
||||
if contact.Ghost.Name != "" {
|
||||
apiContact.Name = contact.Ghost.Name
|
||||
}
|
||||
if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
|
||||
apiContact.Identifiers = contact.Ghost.Identifiers
|
||||
}
|
||||
apiContact.AvatarURL = contact.Ghost.AvatarMXC
|
||||
apiContact.MXID = contact.Ghost.Intent.GetMXID()
|
||||
}
|
||||
if contact.Chat != nil {
|
||||
if contact.Chat.Portal == nil {
|
||||
var err error
|
||||
contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
|
||||
}
|
||||
}
|
||||
if contact.Chat.Portal != nil {
|
||||
apiContact.DMRoomID = contact.Chat.Portal.MXID
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -643,18 +698,30 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque
|
|||
if login == nil {
|
||||
return
|
||||
}
|
||||
resp, err := provisionutil.GetContactList(r.Context(), login)
|
||||
if err != nil {
|
||||
RespondWithError(w, err, "Internal error getting contact list")
|
||||
api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
|
||||
if !ok {
|
||||
mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w)
|
||||
return
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, resp)
|
||||
resp, err := api.GetContactList(r.Context())
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
|
||||
RespondWithError(w, err, "Internal error fetching contact list")
|
||||
return
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{
|
||||
Contacts: prov.processResolveIdentifiers(r.Context(), resp),
|
||||
})
|
||||
}
|
||||
|
||||
type ReqSearchUsers struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
type RespSearchUsers struct {
|
||||
Results []*RespResolveIdentifier `json:"results"`
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) {
|
||||
var req ReqSearchUsers
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
|
|
@ -667,12 +734,20 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ
|
|||
if login == nil {
|
||||
return
|
||||
}
|
||||
resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query)
|
||||
if err != nil {
|
||||
RespondWithError(w, err, "Internal error searching users")
|
||||
api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
|
||||
if !ok {
|
||||
mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w)
|
||||
return
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, resp)
|
||||
resp, err := api.SearchUsers(r.Context(), req.Query)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list")
|
||||
RespondWithError(w, err, "Internal error fetching contact list")
|
||||
return
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{
|
||||
Results: prov.processResolveIdentifiers(r.Context(), resp),
|
||||
})
|
||||
}
|
||||
|
||||
func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -684,24 +759,11 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request
|
|||
}
|
||||
|
||||
func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
var req bridgev2.GroupCreateParams
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body")
|
||||
mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w)
|
||||
return
|
||||
}
|
||||
req.Type = r.PathValue("type")
|
||||
login := prov.GetLoginForRequest(w, r)
|
||||
if login == nil {
|
||||
return
|
||||
}
|
||||
resp, err := provisionutil.CreateGroup(r.Context(), login, &req)
|
||||
if err != nil {
|
||||
RespondWithError(w, err, "Internal error creating group")
|
||||
return
|
||||
}
|
||||
exhttp.WriteJSONResponse(w, http.StatusOK, resp)
|
||||
mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w)
|
||||
}
|
||||
|
||||
type ReqExportCredentials struct {
|
||||
|
|
|
|||
|
|
@ -361,25 +361,14 @@ paths:
|
|||
$ref: '#/components/responses/InternalError'
|
||||
501:
|
||||
$ref: '#/components/responses/NotSupported'
|
||||
/v3/create_group/{type}:
|
||||
/v3/create_group:
|
||||
post:
|
||||
tags: [ snc ]
|
||||
summary: Create a group chat on the remote network.
|
||||
operationId: createGroup
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/loginID"
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/GroupCreateParams'
|
||||
responses:
|
||||
200:
|
||||
description: Identifier resolved successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreatedGroup'
|
||||
401:
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
404:
|
||||
|
|
@ -400,7 +389,7 @@ components:
|
|||
- username
|
||||
- meow@example.com
|
||||
loginID:
|
||||
name: login_id
|
||||
name: loginID
|
||||
in: query
|
||||
description: An optional explicit login ID to do the action through.
|
||||
required: false
|
||||
|
|
@ -583,74 +572,6 @@ components:
|
|||
description: The Matrix room ID of the direct chat with the user.
|
||||
examples:
|
||||
- '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
|
||||
GroupCreateParams:
|
||||
type: object
|
||||
description: |
|
||||
Parameters for creating a group chat.
|
||||
The /capabilities endpoint response must be checked to see which fields are actually allowed.
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: The type of group to create.
|
||||
examples:
|
||||
- channel
|
||||
username:
|
||||
type: string
|
||||
description: The public username for the created group.
|
||||
participants:
|
||||
type: array
|
||||
description: The users to add to the group initially.
|
||||
items:
|
||||
type: string
|
||||
parent:
|
||||
type: object
|
||||
name:
|
||||
type: object
|
||||
description: The `m.room.name` event content for the room.
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
avatar:
|
||||
type: object
|
||||
description: The `m.room.avatar` event content for the room.
|
||||
properties:
|
||||
url:
|
||||
type: string
|
||||
format: mxc
|
||||
topic:
|
||||
type: object
|
||||
description: The `m.room.topic` event content for the room.
|
||||
properties:
|
||||
topic:
|
||||
type: string
|
||||
disappear:
|
||||
type: object
|
||||
description: The `com.beeper.disappearing_timer` event content for the room.
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
timer:
|
||||
type: number
|
||||
room_id:
|
||||
type: string
|
||||
format: matrix_room_id
|
||||
description: |
|
||||
An existing Matrix room ID to bridge to.
|
||||
The other parameters must be already in sync with the room state when using this parameter.
|
||||
CreatedGroup:
|
||||
type: object
|
||||
description: A successfully created group chat.
|
||||
required: [id, mxid]
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: The internal chat ID of the created group.
|
||||
mxid:
|
||||
type: string
|
||||
format: matrix_room_id
|
||||
description: The Matrix room ID of the portal.
|
||||
examples:
|
||||
- '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io'
|
||||
LoginStep:
|
||||
type: object
|
||||
description: A step in a login process.
|
||||
|
|
@ -714,7 +635,7 @@ components:
|
|||
type:
|
||||
type: string
|
||||
description: The type of field.
|
||||
enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ]
|
||||
enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ]
|
||||
id:
|
||||
type: string
|
||||
description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge.
|
||||
|
|
@ -728,53 +649,10 @@ components:
|
|||
description: A more detailed description of the field shown to the user.
|
||||
examples:
|
||||
- Include the country code with a +
|
||||
default_value:
|
||||
type: string
|
||||
description: A default value that the client can pre-fill the field with.
|
||||
pattern:
|
||||
type: string
|
||||
format: regex
|
||||
description: A regular expression that the field value must match.
|
||||
options:
|
||||
type: array
|
||||
description: For fields of type select, the valid options.
|
||||
items:
|
||||
type: string
|
||||
attachments:
|
||||
type: array
|
||||
description: A list of media attachments to show the user alongside the form fields.
|
||||
items:
|
||||
type: object
|
||||
description: A media attachment to show the user.
|
||||
required: [ type, filename, content ]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported.
|
||||
enum: [ m.image, m.audio ]
|
||||
filename:
|
||||
type: string
|
||||
description: The filename for the media attachment.
|
||||
content:
|
||||
type: string
|
||||
description: The raw file content for the attachment encoded in base64.
|
||||
info:
|
||||
type: object
|
||||
description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted.
|
||||
properties:
|
||||
mimetype:
|
||||
type: string
|
||||
description: The MIME type for the media content.
|
||||
examples: [ image/png, audio/mpeg ]
|
||||
w:
|
||||
type: number
|
||||
description: The width of the media in pixels. Only applicable for images and videos.
|
||||
h:
|
||||
type: number
|
||||
description: The height of the media in pixels. Only applicable for images and videos.
|
||||
size:
|
||||
type: number
|
||||
description: The size of the media content in number of bytes. Strongly recommended to include.
|
||||
- description: Cookie login step
|
||||
required: [ type, cookies ]
|
||||
properties:
|
||||
|
|
|
|||
|
|
@ -7,26 +7,16 @@
|
|||
package matrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/database"
|
||||
"maunium.net/go/mautrix/crypto/attachment"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
|
|
@ -43,10 +33,7 @@ func (br *Connector) initPublicMedia() error {
|
|||
return fmt.Errorf("public media hash length is negative")
|
||||
}
|
||||
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
|
||||
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia)
|
||||
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia)
|
||||
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
|
||||
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -57,20 +44,6 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte {
|
|||
return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)]
|
||||
}
|
||||
|
||||
func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte {
|
||||
hasher := hmac.New(sha256.New, br.pubMediaSigKey)
|
||||
hasher.Write([]byte(pm.MXC.String()))
|
||||
hasher.Write([]byte(pm.MimeType))
|
||||
if pm.Keys != nil {
|
||||
hasher.Write([]byte(pm.Keys.Version))
|
||||
hasher.Write([]byte(pm.Keys.Key.Algorithm))
|
||||
hasher.Write([]byte(pm.Keys.Key.Key))
|
||||
hasher.Write([]byte(pm.Keys.InitVector))
|
||||
hasher.Write([]byte(pm.Keys.Hashes.SHA256))
|
||||
}
|
||||
return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength]
|
||||
}
|
||||
|
||||
func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte {
|
||||
var expiresAt []byte
|
||||
if br.Config.PublicMedia.Expiry > 0 {
|
||||
|
|
@ -120,47 +93,9 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, "checksum expired", http.StatusGone)
|
||||
return
|
||||
}
|
||||
br.doProxyMedia(w, r, contentURI, nil, "")
|
||||
}
|
||||
|
||||
func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) {
|
||||
if !br.Config.PublicMedia.UseDatabase {
|
||||
http.Error(w, "public media short links are disabled", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
log := zerolog.Ctx(r.Context())
|
||||
media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID"))
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get public media from database")
|
||||
http.Error(w, "failed to get media metadata", http.StatusInternalServerError)
|
||||
return
|
||||
} else if media == nil {
|
||||
http.Error(w, "media ID not found", http.StatusNotFound)
|
||||
return
|
||||
} else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) {
|
||||
// This is not gone as it can still be refreshed in the DB
|
||||
http.Error(w, "media expired", http.StatusNotFound)
|
||||
return
|
||||
} else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil {
|
||||
http.Error(w, "media keys are malformed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType)
|
||||
}
|
||||
|
||||
var safeMimes = []string{
|
||||
"text/css", "text/plain", "text/csv",
|
||||
"application/json", "application/ld+json",
|
||||
"image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif",
|
||||
"video/mp4", "video/webm", "video/ogg", "video/quicktime",
|
||||
"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave",
|
||||
"audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac",
|
||||
}
|
||||
|
||||
func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) {
|
||||
resp, err := br.Bot.Download(r.Context(), contentURI)
|
||||
if err != nil {
|
||||
zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
|
||||
br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy")
|
||||
http.Error(w, "failed to download media", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -168,41 +103,11 @@ func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, conten
|
|||
for _, hdr := range proxyHeadersToCopy {
|
||||
w.Header()[hdr] = resp.Header[hdr]
|
||||
}
|
||||
stream := resp.Body
|
||||
if encInfo != nil {
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
contentDisposition := "attachment"
|
||||
if slices.Contains(safeMimes, mimeType) {
|
||||
contentDisposition = "inline"
|
||||
}
|
||||
dispositionArgs := map[string]string{}
|
||||
if filename := r.PathValue("filename"); filename != "" {
|
||||
dispositionArgs["filename"] = filename
|
||||
}
|
||||
w.Header().Set("Content-Type", mimeType)
|
||||
w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs))
|
||||
// Note: this won't check the Close result like it should, but it's probably not a big deal here
|
||||
stream = encInfo.DecryptStream(stream)
|
||||
} else if filename := r.PathValue("filename"); filename != "" {
|
||||
contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
|
||||
if contentDisposition == "" {
|
||||
contentDisposition = "attachment"
|
||||
}
|
||||
w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{
|
||||
"filename": filename,
|
||||
}))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(w, stream)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}
|
||||
|
||||
func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string {
|
||||
return br.getPublicMediaAddressWithFileName(contentURI, "")
|
||||
}
|
||||
|
||||
func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string {
|
||||
if br.pubMediaSigKey == nil {
|
||||
return ""
|
||||
}
|
||||
|
|
@ -210,69 +115,11 @@ func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIS
|
|||
if err != nil || !parsed.IsValid() {
|
||||
return ""
|
||||
}
|
||||
fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_"))
|
||||
if fileName == ".." {
|
||||
fileName = ""
|
||||
}
|
||||
parts := []string{
|
||||
return fmt.Sprintf(
|
||||
"%s/_mautrix/publicmedia/%s/%s/%s",
|
||||
br.GetPublicAddress(),
|
||||
strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
|
||||
parsed.Homeserver,
|
||||
parsed.FileID,
|
||||
base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)),
|
||||
fileName,
|
||||
}
|
||||
if fileName == "" {
|
||||
parts = parts[:len(parts)-1]
|
||||
}
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
|
||||
func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) {
|
||||
if br.pubMediaSigKey == nil {
|
||||
return "", bridgev2.ErrPublicMediaDisabled
|
||||
}
|
||||
if !br.Config.PublicMedia.UseDatabase {
|
||||
if evt.File != nil {
|
||||
return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled)
|
||||
}
|
||||
return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil
|
||||
}
|
||||
mxc := evt.URL
|
||||
var keys *attachment.EncryptedFile
|
||||
if evt.File != nil {
|
||||
mxc = evt.File.URL
|
||||
keys = &evt.File.EncryptedFile
|
||||
}
|
||||
parsedMXC, err := mxc.Parse()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
|
||||
}
|
||||
pm := &database.PublicMedia{
|
||||
MXC: parsedMXC,
|
||||
Keys: keys,
|
||||
MimeType: evt.GetInfo().MimeType,
|
||||
}
|
||||
if br.Config.PublicMedia.Expiry > 0 {
|
||||
pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second)
|
||||
}
|
||||
pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm))
|
||||
err = br.Bridge.DB.PublicMedia.Put(ctx, pm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err)
|
||||
}
|
||||
fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_"))
|
||||
if fileName == ".." {
|
||||
fileName = ""
|
||||
}
|
||||
parts := []string{
|
||||
br.GetPublicAddress(),
|
||||
strings.Trim(br.Config.PublicMedia.PathPrefix, "/"),
|
||||
pm.PublicID,
|
||||
fileName,
|
||||
}
|
||||
if fileName == "" {
|
||||
parts = parts[:len(parts)-1]
|
||||
}
|
||||
return strings.Join(parts, "/"), nil
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ import (
|
|||
"os"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/exhttp"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2/database"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
|
|
@ -25,10 +23,8 @@ import (
|
|||
)
|
||||
|
||||
type MatrixCapabilities struct {
|
||||
AutoJoinInvites bool
|
||||
BatchSending bool
|
||||
ArbitraryMemberChange bool
|
||||
ExtraProfileMeta bool
|
||||
AutoJoinInvites bool
|
||||
BatchSending bool
|
||||
}
|
||||
|
||||
type MatrixConnector interface {
|
||||
|
|
@ -62,54 +58,35 @@ type MatrixConnector interface {
|
|||
}
|
||||
|
||||
type MatrixConnectorWithArbitraryRoomState interface {
|
||||
MatrixConnector
|
||||
GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error)
|
||||
}
|
||||
|
||||
type MatrixConnectorWithServer interface {
|
||||
MatrixConnector
|
||||
GetPublicAddress() string
|
||||
GetRouter() *http.ServeMux
|
||||
}
|
||||
|
||||
type IProvisioningAPI interface {
|
||||
GetRouter() *http.ServeMux
|
||||
GetUser(r *http.Request) *User
|
||||
}
|
||||
|
||||
type MatrixConnectorWithProvisioning interface {
|
||||
MatrixConnector
|
||||
GetProvisioning() IProvisioningAPI
|
||||
}
|
||||
|
||||
type MatrixConnectorWithPublicMedia interface {
|
||||
MatrixConnector
|
||||
GetPublicMediaAddress(contentURI id.ContentURIString) string
|
||||
GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error)
|
||||
}
|
||||
|
||||
type MatrixConnectorWithNameDisambiguation interface {
|
||||
MatrixConnector
|
||||
IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error)
|
||||
}
|
||||
|
||||
type MatrixConnectorWithBridgeIdentifier interface {
|
||||
MatrixConnector
|
||||
GetUniqueBridgeID() string
|
||||
}
|
||||
|
||||
type MatrixConnectorWithURLPreviews interface {
|
||||
MatrixConnector
|
||||
GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error)
|
||||
}
|
||||
|
||||
type MatrixConnectorWithPostRoomBridgeHandling interface {
|
||||
MatrixConnector
|
||||
HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error
|
||||
}
|
||||
|
||||
type MatrixConnectorWithAnalytics interface {
|
||||
MatrixConnector
|
||||
TrackAnalytics(userID id.UserID, event string, properties map[string]any)
|
||||
}
|
||||
|
||||
|
|
@ -124,15 +101,9 @@ type DirectNotificationData struct {
|
|||
}
|
||||
|
||||
type MatrixConnectorWithNotifications interface {
|
||||
MatrixConnector
|
||||
DisplayNotification(ctx context.Context, data *DirectNotificationData)
|
||||
}
|
||||
|
||||
type MatrixConnectorWithHTTPSettings interface {
|
||||
MatrixConnector
|
||||
GetHTTPClientSettings() exhttp.ClientSettings
|
||||
}
|
||||
|
||||
type MatrixSendExtra struct {
|
||||
Timestamp time.Time
|
||||
MessageMeta *database.Message
|
||||
|
|
@ -205,21 +176,12 @@ type MatrixAPI interface {
|
|||
|
||||
TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error
|
||||
MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error
|
||||
|
||||
GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error)
|
||||
}
|
||||
|
||||
type StreamOrderReadingMatrixAPI interface {
|
||||
MatrixAPI
|
||||
MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error
|
||||
}
|
||||
|
||||
type MarkAsDMMatrixAPI interface {
|
||||
MatrixAPI
|
||||
MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error
|
||||
}
|
||||
|
||||
type EphemeralSendingMatrixAPI interface {
|
||||
MatrixAPI
|
||||
BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,36 +88,6 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI,
|
|||
rejectInvite(ctx, evt, intent, "")
|
||||
}
|
||||
|
||||
func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) {
|
||||
if portal.MXID == "" {
|
||||
return
|
||||
}
|
||||
log := zerolog.Ctx(ctx)
|
||||
existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Stringer("old_portal_mxid", portal.MXID).
|
||||
Msg("Failed to check existing portal members, deleting room")
|
||||
} else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok {
|
||||
log.Debug().
|
||||
Stringer("old_portal_mxid", portal.MXID).
|
||||
Msg("Inviter has no member event in old portal, deleting room")
|
||||
} else if targetUserMember.Membership.IsInviteOrJoin() {
|
||||
return
|
||||
} else {
|
||||
log.Debug().
|
||||
Stringer("old_portal_mxid", portal.MXID).
|
||||
Str("membership", string(targetUserMember.Membership)).
|
||||
Msg("Inviter is not in old portal, deleting room")
|
||||
}
|
||||
|
||||
if err = portal.RemoveMXID(ctx); err != nil {
|
||||
log.Err(err).Msg("Failed to delete old portal mxid")
|
||||
} else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
|
||||
log.Err(err).Msg("Failed to clean up old portal room")
|
||||
}
|
||||
}
|
||||
|
||||
func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult {
|
||||
ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey()))
|
||||
validator, ok := br.Network.(IdentifierValidatingNetwork)
|
||||
|
|
@ -195,7 +165,34 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
|
|||
return EventHandlingResultFailed
|
||||
}
|
||||
}
|
||||
portal.CleanupOrphanedDM(ctx, sender.MXID)
|
||||
if portal.MXID != "" {
|
||||
doCleanup := true
|
||||
existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Stringer("old_portal_mxid", portal.MXID).
|
||||
Msg("Failed to check existing portal members, deleting room")
|
||||
} else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok {
|
||||
log.Debug().
|
||||
Stringer("old_portal_mxid", portal.MXID).
|
||||
Msg("Inviter has no member event in old portal, deleting room")
|
||||
} else if targetUserMember.Membership.IsInviteOrJoin() {
|
||||
doCleanup = false
|
||||
} else {
|
||||
log.Debug().
|
||||
Stringer("old_portal_mxid", portal.MXID).
|
||||
Str("membership", string(targetUserMember.Membership)).
|
||||
Msg("Inviter is not in old portal, deleting room")
|
||||
}
|
||||
|
||||
if doCleanup {
|
||||
if err = portal.RemoveMXID(ctx); err != nil {
|
||||
log.Err(err).Msg("Failed to delete old portal mxid")
|
||||
} else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil {
|
||||
log.Err(err).Msg("Failed to clean up old portal room")
|
||||
}
|
||||
}
|
||||
}
|
||||
err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID())
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to ensure bot is invited to room")
|
||||
|
|
@ -209,67 +206,72 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen
|
|||
return EventHandlingResultFailed
|
||||
}
|
||||
|
||||
portal.roomCreateLock.Lock()
|
||||
defer portal.roomCreateLock.Unlock()
|
||||
portalMXID := portal.MXID
|
||||
if portalMXID != "" {
|
||||
sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL())
|
||||
rejectInvite(ctx, evt, br.Bot, "")
|
||||
return EventHandlingResultSuccess
|
||||
}
|
||||
err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to give permissions to bridge bot")
|
||||
sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot")
|
||||
rejectInvite(ctx, evt, br.Bot, "")
|
||||
return EventHandlingResultSuccess
|
||||
}
|
||||
overrideIntent := invitedGhost.Intent
|
||||
if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
|
||||
log.Debug().
|
||||
Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
|
||||
Msg("Created DM was redirected to another user ID")
|
||||
_, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
|
||||
Parsed: &event.MemberEventContent{
|
||||
Membership: event.MembershipLeave,
|
||||
Reason: "Direct chat redirected to another internal user ID",
|
||||
didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID)
|
||||
if didSetPortal {
|
||||
message := "Private chat portal created"
|
||||
err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent)
|
||||
hasWarning := false
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to give power to bot in new DM")
|
||||
message += "\n\nWarning: failed to promote bot"
|
||||
hasWarning = true
|
||||
}
|
||||
if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID {
|
||||
log.Debug().
|
||||
Str("dm_redirected_to_id", string(resp.DMRedirectedTo)).
|
||||
Msg("Created DM was redirected to another user ID")
|
||||
_, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{
|
||||
Parsed: &event.MemberEventContent{
|
||||
Membership: event.MembershipLeave,
|
||||
Reason: "Direct chat redirected to another internal user ID",
|
||||
},
|
||||
}, time.Time{})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
|
||||
}
|
||||
otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get ghost of real portal other user ID")
|
||||
} else {
|
||||
invitedGhost = otherUserGhost
|
||||
}
|
||||
}
|
||||
if resp.PortalInfo != nil {
|
||||
portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{})
|
||||
} else {
|
||||
portal.UpdateCapabilities(ctx, sourceLogin, true)
|
||||
portal.UpdateBridgeInfo(ctx)
|
||||
}
|
||||
// TODO this might become unnecessary if UpdateInfo starts taking care of it
|
||||
_, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{
|
||||
Parsed: &event.ElementFunctionalMembersContent{
|
||||
ServiceMembers: []id.UserID{br.Bot.GetMXID()},
|
||||
},
|
||||
}, time.Time{})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to make incorrect ghost leave new DM room")
|
||||
log.Warn().Err(err).Msg("Failed to set service members in room")
|
||||
if !hasWarning {
|
||||
message += "\n\nWarning: failed to set service members"
|
||||
hasWarning = true
|
||||
}
|
||||
}
|
||||
if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot {
|
||||
overrideIntent = br.Bot
|
||||
} else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil {
|
||||
log.Err(err).Msg("Failed to get ghost of real portal other user ID")
|
||||
} else {
|
||||
invitedGhost = otherUserGhost
|
||||
overrideIntent = otherUserGhost.Intent
|
||||
mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
|
||||
if ok {
|
||||
err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
|
||||
if err != nil {
|
||||
if hasWarning {
|
||||
message += fmt.Sprintf(", %s", err.Error())
|
||||
} else {
|
||||
message += fmt.Sprintf("\n\nWarning: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
sendNotice(ctx, evt, invitedGhost.Intent, message)
|
||||
} else {
|
||||
// TODO ensure user is invited even if PortalInfo wasn't provided?
|
||||
sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL())
|
||||
rejectInvite(ctx, evt, br.Bot, "")
|
||||
}
|
||||
err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{
|
||||
// We locked it before checking the mxid
|
||||
RoomCreateAlreadyLocked: true,
|
||||
|
||||
FailIfMXIDSet: true,
|
||||
ChatInfo: resp.PortalInfo,
|
||||
ChatInfoSource: sourceLogin,
|
||||
})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to update Matrix room ID for new DM portal")
|
||||
sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work")
|
||||
return EventHandlingResultSuccess
|
||||
}
|
||||
message := "Private chat portal created"
|
||||
mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling)
|
||||
if ok {
|
||||
err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Error in connector newly bridged room handler")
|
||||
message += fmt.Sprintf("\n\nWarning: %s", err.Error())
|
||||
}
|
||||
}
|
||||
sendNotice(ctx, evt, overrideIntent, message)
|
||||
return EventHandlingResultSuccess
|
||||
}
|
||||
|
||||
|
|
@ -292,3 +294,21 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (portal *Portal) setMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
|
||||
portal.roomCreateLock.Lock()
|
||||
defer portal.roomCreateLock.Unlock()
|
||||
if portal.MXID != "" {
|
||||
return false
|
||||
}
|
||||
portal.MXID = roomID
|
||||
portal.updateLogger()
|
||||
portal.Bridge.cacheLock.Lock()
|
||||
portal.Bridge.portalsByMXID[portal.MXID] = portal
|
||||
portal.Bridge.cacheLock.Unlock()
|
||||
err := portal.Save(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating mxid")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
|
||||
type MessageStatusEventInfo struct {
|
||||
RoomID id.RoomID
|
||||
TransactionID string
|
||||
SourceEventID id.EventID
|
||||
NewEventID id.EventID
|
||||
EventType event.Type
|
||||
|
|
@ -42,7 +41,6 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo {
|
|||
|
||||
return &MessageStatusEventInfo{
|
||||
RoomID: evt.RoomID,
|
||||
TransactionID: evt.Unsigned.TransactionID,
|
||||
SourceEventID: evt.ID,
|
||||
EventType: evt.Type,
|
||||
MessageType: evt.Content.AsMessage().MsgType,
|
||||
|
|
@ -184,10 +182,9 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe
|
|||
Type: event.RelReference,
|
||||
EventID: evt.SourceEventID,
|
||||
},
|
||||
TargetTxnID: evt.TransactionID,
|
||||
Status: ms.Status,
|
||||
Reason: ms.ErrorReason,
|
||||
Message: ms.Message,
|
||||
Status: ms.Status,
|
||||
Reason: ms.ErrorReason,
|
||||
Message: ms.Message,
|
||||
}
|
||||
if ms.InternalError != nil {
|
||||
content.InternalError = ms.InternalError.Error()
|
||||
|
|
|
|||
|
|
@ -47,8 +47,8 @@ type PortalID string
|
|||
// As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true.
|
||||
// The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user.
|
||||
type PortalKey struct {
|
||||
ID PortalID `json:"portal_id"`
|
||||
Receiver UserLoginID `json:"portal_receiver,omitempty"`
|
||||
ID PortalID
|
||||
Receiver UserLoginID
|
||||
}
|
||||
|
||||
func (pk PortalKey) IsEmpty() bool {
|
||||
|
|
|
|||
|
|
@ -16,9 +16,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/configupgrade"
|
||||
"go.mau.fi/util/ptr"
|
||||
"go.mau.fi/util/random"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2/database"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
|
@ -119,15 +117,11 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa
|
|||
mediaPart.Content.EnsureHasHTML()
|
||||
mediaPart.Content.Body += "\n\n" + textPart.Content.Body
|
||||
mediaPart.Content.FormattedBody += "<br><br>" + textPart.Content.FormattedBody
|
||||
mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions)
|
||||
mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...)
|
||||
} else {
|
||||
mediaPart.Content.FileName = mediaPart.Content.Body
|
||||
mediaPart.Content.Body = textPart.Content.Body
|
||||
mediaPart.Content.Format = textPart.Content.Format
|
||||
mediaPart.Content.FormattedBody = textPart.Content.FormattedBody
|
||||
mediaPart.Content.Mentions = textPart.Content.Mentions
|
||||
mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews
|
||||
}
|
||||
if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok {
|
||||
metaMerger.CopyFrom(textPart.DBMetadata)
|
||||
|
|
@ -261,7 +255,6 @@ type NetworkConnector interface {
|
|||
}
|
||||
|
||||
type StoppableNetwork interface {
|
||||
NetworkConnector
|
||||
// Stop is called when the bridge is stopping, after all network clients have been disconnected.
|
||||
Stop()
|
||||
}
|
||||
|
|
@ -318,16 +311,6 @@ type MaxFileSizeingNetwork interface {
|
|||
SetMaxFileSize(maxSize int64)
|
||||
}
|
||||
|
||||
type NetworkResettingNetwork interface {
|
||||
NetworkConnector
|
||||
// ResetHTTPTransport should recreate the HTTP client used by the bridge.
|
||||
// It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable.
|
||||
ResetHTTPTransport()
|
||||
// ResetNetworkConnections should forcefully disconnect and restart any persistent network connections.
|
||||
// ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here.
|
||||
ResetNetworkConnections()
|
||||
}
|
||||
|
||||
type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error)
|
||||
|
||||
type MatrixMessageResponse struct {
|
||||
|
|
@ -359,16 +342,10 @@ type NetworkGeneralCapabilities struct {
|
|||
// Should the bridge re-request user info on incoming messages even if the ghost already has info?
|
||||
// By default, info is only requested for ghosts with no name, and other updating is left to events.
|
||||
AggressiveUpdateInfo bool
|
||||
// Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message?
|
||||
// This should be enabled if the network requires each message to be marked as read independently,
|
||||
// and doesn't automatically do it when sending a message.
|
||||
ImplicitReadReceipts bool
|
||||
// If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave])
|
||||
// to handle asynchronous message responses, this field can be set to enable
|
||||
// automatic timeout errors in case the asynchronous response never arrives.
|
||||
OutgoingMessageTimeouts *OutgoingTimeoutConfig
|
||||
// Capabilities related to the provisioning API.
|
||||
Provisioning ProvisioningCapabilities
|
||||
}
|
||||
|
||||
// NetworkAPI is an interface representing a remote network client for a single user login.
|
||||
|
|
@ -702,35 +679,6 @@ type RoomTopicHandlingNetworkAPI interface {
|
|||
HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error)
|
||||
}
|
||||
|
||||
type DisappearTimerChangingNetworkAPI interface {
|
||||
NetworkAPI
|
||||
// HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed.
|
||||
// This method should update the Disappear field of the Portal with the new timer and return true
|
||||
// if the change was successful. If the change is not successful, then the field should not be updated.
|
||||
HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error)
|
||||
}
|
||||
|
||||
// DeleteChatHandlingNetworkAPI is an optional interface that network connectors
|
||||
// can implement to delete a chat from the remote network.
|
||||
type DeleteChatHandlingNetworkAPI interface {
|
||||
NetworkAPI
|
||||
// HandleMatrixDeleteChat is called when the user explicitly deletes a chat.
|
||||
HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error
|
||||
}
|
||||
|
||||
// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors
|
||||
// can implement to accept message requests from the remote network.
|
||||
type MessageRequestAcceptingNetworkAPI interface {
|
||||
NetworkAPI
|
||||
// HandleMatrixAcceptMessageRequest is called when the user accepts a message request.
|
||||
HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error
|
||||
}
|
||||
|
||||
type BeeperAIStreamHandlingNetworkAPI interface {
|
||||
NetworkAPI
|
||||
HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error
|
||||
}
|
||||
|
||||
type ResolveIdentifierResponse struct {
|
||||
// Ghost is the ghost of the user that the identifier resolves to.
|
||||
// This field should be set whenever possible. However, it is not required,
|
||||
|
|
@ -750,8 +698,6 @@ type ResolveIdentifierResponse struct {
|
|||
Chat *CreateChatResponse
|
||||
}
|
||||
|
||||
var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10))
|
||||
|
||||
type CreateChatResponse struct {
|
||||
PortalKey networkid.PortalKey
|
||||
// Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary.
|
||||
|
|
@ -760,17 +706,6 @@ type CreateChatResponse struct {
|
|||
// If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user,
|
||||
// this field should have the user ID of said different user.
|
||||
DMRedirectedTo networkid.UserID
|
||||
|
||||
FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant
|
||||
}
|
||||
|
||||
type CreateChatFailedParticipant struct {
|
||||
Reason string `json:"reason"`
|
||||
InviteEventType string `json:"invite_event_type,omitempty"`
|
||||
InviteContent *event.Content `json:"invite_content,omitempty"`
|
||||
|
||||
UserMXID id.UserID `json:"user_mxid,omitempty"`
|
||||
DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"`
|
||||
}
|
||||
|
||||
// IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats.
|
||||
|
|
@ -805,83 +740,7 @@ type UserSearchingNetworkAPI interface {
|
|||
|
||||
type GroupCreatingNetworkAPI interface {
|
||||
IdentifierResolvingNetworkAPI
|
||||
CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error)
|
||||
}
|
||||
|
||||
type PersonalFilteringCustomizingNetworkAPI interface {
|
||||
NetworkAPI
|
||||
CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom)
|
||||
}
|
||||
|
||||
type ProvisioningCapabilities struct {
|
||||
ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"`
|
||||
GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"`
|
||||
}
|
||||
|
||||
type ResolveIdentifierCapabilities struct {
|
||||
// Can DMs be created after resolving an identifier?
|
||||
CreateDM bool `json:"create_dm"`
|
||||
// Can users be looked up by phone number?
|
||||
LookupPhone bool `json:"lookup_phone"`
|
||||
// Can users be looked up by email address?
|
||||
LookupEmail bool `json:"lookup_email"`
|
||||
// Can users be looked up by network-specific username?
|
||||
LookupUsername bool `json:"lookup_username"`
|
||||
// Can any phone number be contacted without having to validate it via lookup first?
|
||||
AnyPhone bool `json:"any_phone"`
|
||||
// Can a contact list be retrieved from the bridge?
|
||||
ContactList bool `json:"contact_list"`
|
||||
// Can users be searched by name on the remote network?
|
||||
Search bool `json:"search"`
|
||||
}
|
||||
|
||||
type GroupTypeCapabilities struct {
|
||||
TypeDescription string `json:"type_description"`
|
||||
|
||||
Name GroupFieldCapability `json:"name"`
|
||||
Username GroupFieldCapability `json:"username"`
|
||||
Avatar GroupFieldCapability `json:"avatar"`
|
||||
Topic GroupFieldCapability `json:"topic"`
|
||||
Disappear GroupFieldCapability `json:"disappear"`
|
||||
Participants GroupFieldCapability `json:"participants"`
|
||||
Parent GroupFieldCapability `json:"parent"`
|
||||
}
|
||||
|
||||
type GroupFieldCapability struct {
|
||||
// Is setting this field allowed at all in the create request?
|
||||
// Even if false, the network connector should attempt to set the metadata after group creation,
|
||||
// as the allowed flag can't be enforced properly when creating a group for an existing Matrix room.
|
||||
Allowed bool `json:"allowed"`
|
||||
// Is setting this field mandatory for the creation to succeed?
|
||||
Required bool `json:"required,omitempty"`
|
||||
// The minimum/maximum length of the field, if applicable.
|
||||
// For members, length means the number of members excluding the creator.
|
||||
MinLength int `json:"min_length,omitempty"`
|
||||
MaxLength int `json:"max_length,omitempty"`
|
||||
|
||||
// Only for the disappear field: allowed disappearing settings
|
||||
DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"`
|
||||
|
||||
// This can be used to tell provisionutil not to call ValidateUserID on each participant.
|
||||
// It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs.
|
||||
SkipIdentifierValidation bool `json:"-"`
|
||||
}
|
||||
|
||||
type GroupCreateParams struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
|
||||
Username string `json:"username,omitempty"`
|
||||
// Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs
|
||||
Participants []networkid.UserID `json:"participants,omitempty"`
|
||||
Parent *networkid.PortalKey `json:"parent,omitempty"`
|
||||
|
||||
Name *event.RoomNameEventContent `json:"name,omitempty"`
|
||||
Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"`
|
||||
Topic *event.TopicEventContent `json:"topic,omitempty"`
|
||||
Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"`
|
||||
|
||||
// An existing room ID to bridge to. If unset, a new room will be created.
|
||||
RoomID id.RoomID `json:"room_id,omitempty"`
|
||||
CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error)
|
||||
}
|
||||
|
||||
type MembershipChangeType struct {
|
||||
|
|
@ -921,15 +780,16 @@ type MatrixMembershipChange struct {
|
|||
MatrixRoomMeta[*event.MemberEventContent]
|
||||
Target GhostOrUserLogin
|
||||
Type MembershipChangeType
|
||||
}
|
||||
|
||||
type MatrixMembershipResult struct {
|
||||
RedirectTo networkid.UserID
|
||||
// Deprecated: Use Target instead
|
||||
TargetGhost *Ghost
|
||||
// Deprecated: Use Target instead
|
||||
TargetUserLogin *UserLogin
|
||||
}
|
||||
|
||||
type MembershipHandlingNetworkAPI interface {
|
||||
NetworkAPI
|
||||
HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error)
|
||||
HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error)
|
||||
}
|
||||
|
||||
type SinglePowerLevelChange struct {
|
||||
|
|
@ -1168,11 +1028,6 @@ type RemoteChatDelete interface {
|
|||
RemoteDeleteOnlyForMe
|
||||
}
|
||||
|
||||
type RemoteChatDeleteWithChildren interface {
|
||||
RemoteChatDelete
|
||||
DeleteChildren() bool
|
||||
}
|
||||
|
||||
type RemoteEventThatMayCreatePortal interface {
|
||||
RemoteEvent
|
||||
ShouldCreatePortal() bool
|
||||
|
|
@ -1405,14 +1260,12 @@ type MatrixMessageRemove struct {
|
|||
|
||||
type MatrixRoomMeta[ContentType any] struct {
|
||||
MatrixEventBase[ContentType]
|
||||
PrevContent ContentType
|
||||
IsStateRequest bool
|
||||
PrevContent ContentType
|
||||
}
|
||||
|
||||
type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent]
|
||||
type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent]
|
||||
type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent]
|
||||
type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer]
|
||||
|
||||
type MatrixReadReceipt struct {
|
||||
Portal *Portal
|
||||
|
|
@ -1427,8 +1280,6 @@ type MatrixReadReceipt struct {
|
|||
LastRead time.Time
|
||||
// The receipt metadata.
|
||||
Receipt event.ReadReceipt
|
||||
// Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt.
|
||||
Implicit bool
|
||||
}
|
||||
|
||||
type MatrixTyping struct {
|
||||
|
|
@ -1442,9 +1293,6 @@ type MatrixViewingChat struct {
|
|||
Portal *Portal
|
||||
}
|
||||
|
||||
type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent]
|
||||
type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent]
|
||||
type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent]
|
||||
type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent]
|
||||
type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent]
|
||||
type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent]
|
||||
|
|
|
|||
1341
bridgev2/portal.go
1341
bridgev2/portal.go
File diff suppressed because it is too large
Load diff
|
|
@ -194,9 +194,6 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t
|
|||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get last thread message")
|
||||
return
|
||||
} else if anchorMessage == nil {
|
||||
log.Warn().Msg("No messages found in thread?")
|
||||
return
|
||||
}
|
||||
resp := portal.fetchThreadBackfill(ctx, source, anchorMessage)
|
||||
if resp != nil {
|
||||
|
|
@ -342,7 +339,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
|
|||
for i, part := range msg.Parts {
|
||||
partIDs = append(partIDs, part.ID)
|
||||
portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent)
|
||||
part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent()
|
||||
evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID)
|
||||
dbMessage := &database.Message{
|
||||
ID: msg.ID,
|
||||
|
|
@ -383,23 +379,19 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
|
|||
prevThreadEvent.MXID = evtID
|
||||
out.PrevThreadEvents[*msg.ThreadRoot] = evtID
|
||||
}
|
||||
if msg.Disappear.Type != event.DisappearingTypeNone {
|
||||
if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
|
||||
if msg.Disappear.Type != database.DisappearingTypeNone {
|
||||
if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() {
|
||||
msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer)
|
||||
}
|
||||
out.Disappear = append(out.Disappear, &database.DisappearingMessage{
|
||||
RoomID: portal.MXID,
|
||||
EventID: evtID,
|
||||
Timestamp: msg.Timestamp,
|
||||
DisappearingSetting: msg.Disappear,
|
||||
})
|
||||
}
|
||||
}
|
||||
slices.Sort(partIDs)
|
||||
for _, reaction := range msg.Reactions {
|
||||
if reaction == nil {
|
||||
continue
|
||||
}
|
||||
reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove)
|
||||
if !ok {
|
||||
continue
|
||||
|
|
@ -410,7 +402,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin
|
|||
if reaction.Timestamp.IsZero() {
|
||||
reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond)
|
||||
}
|
||||
//lint:ignore SA4006 it's a todo
|
||||
targetPart, ok := partMap[*reaction.TargetPart]
|
||||
if !ok {
|
||||
// TODO warning log and/or skip reaction?
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ func (portal *PortalInternals) EventLoop() {
|
|||
(*Portal)(portal).eventLoop()
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) {
|
||||
return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt)
|
||||
func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) {
|
||||
return (*Portal)(portal).handleSingleEventAsync(idx, rawEvt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context {
|
||||
|
|
@ -49,10 +49,6 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any
|
|||
(*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error {
|
||||
return (*Portal)(portal).unwrapBeeperSendState(ctx, evt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) {
|
||||
(*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID)
|
||||
}
|
||||
|
|
@ -65,8 +61,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i
|
|||
return (*Portal)(portal).checkConfusableName(ctx, userID, name)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest)
|
||||
func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult {
|
||||
|
|
@ -77,10 +73,6 @@ func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user
|
|||
(*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) {
|
||||
(*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixTyping(ctx, evt)
|
||||
}
|
||||
|
|
@ -125,24 +117,12 @@ func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.User
|
|||
return (*Portal)(portal).getTargetUser(ctx, userID)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt)
|
||||
func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixTombstone(ctx, evt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) {
|
||||
(*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser)
|
||||
func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
|
||||
return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult {
|
||||
|
|
@ -153,10 +133,6 @@ func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *Us
|
|||
return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) {
|
||||
(*Portal)(portal).ensureFunctionalMember(ctx, ghost)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) {
|
||||
return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType)
|
||||
}
|
||||
|
|
@ -257,10 +233,6 @@ func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, sourc
|
|||
return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) {
|
||||
return (*Portal)(portal).findOtherLogins(ctx, source)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult {
|
||||
return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt)
|
||||
}
|
||||
|
|
@ -269,16 +241,16 @@ func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source
|
|||
return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
|
||||
return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline)
|
||||
func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool {
|
||||
return (*Portal)(portal).updateName(ctx, name, sender, ts)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
|
||||
return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline)
|
||||
func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool {
|
||||
return (*Portal)(portal).updateTopic(ctx, topic, sender, ts)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool {
|
||||
return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline)
|
||||
func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool {
|
||||
return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) GetBridgeInfoStateKey() string {
|
||||
|
|
@ -293,12 +265,8 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen
|
|||
return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool {
|
||||
return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) {
|
||||
(*Portal)(portal).revertRoomMeta(ctx, evt)
|
||||
func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool {
|
||||
return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) {
|
||||
|
|
@ -309,10 +277,6 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha
|
|||
return (*Portal)(portal).updateOtherUser(ctx, members)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool {
|
||||
return (*Portal)(portal).roomIsPublic(ctx)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error {
|
||||
return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts)
|
||||
}
|
||||
|
|
@ -333,10 +297,6 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc
|
|||
return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) {
|
||||
(*Portal)(portal).addToUserSpaces(ctx)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) {
|
||||
(*Portal)(portal).removeInPortalCache(ctx)
|
||||
}
|
||||
|
|
@ -400,3 +360,7 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save
|
|||
func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error {
|
||||
return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove)
|
||||
}
|
||||
|
||||
func (portal *PortalInternals) SetMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool {
|
||||
return (*Portal)(portal).setMXIDToExistingRoom(ctx, roomID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,40 +32,21 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
|
|||
if source == target {
|
||||
return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same")
|
||||
}
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("action", "re-id portal").
|
||||
Stringer("source_portal_key", source).
|
||||
Stringer("target_portal_key", target).
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
log := zerolog.Ctx(ctx)
|
||||
log.Debug().Msg("Re-ID'ing portal")
|
||||
defer func() {
|
||||
log.Debug().Msg("Finished handling portal re-ID")
|
||||
}()
|
||||
acquireCacheLock := func() {
|
||||
if !br.cacheLock.TryLock() {
|
||||
log.Debug().Msg("Waiting for global cache lock")
|
||||
br.cacheLock.Lock()
|
||||
log.Debug().Msg("Acquired global cache lock after waiting")
|
||||
} else {
|
||||
log.Trace().Msg("Acquired global cache lock without waiting")
|
||||
}
|
||||
}
|
||||
log.Debug().Msg("Re-ID'ing portal")
|
||||
sourcePortal, err := br.GetExistingPortalByKey(ctx, source)
|
||||
br.cacheLock.Lock()
|
||||
defer br.cacheLock.Unlock()
|
||||
sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true)
|
||||
if err != nil {
|
||||
return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err)
|
||||
} else if sourcePortal == nil {
|
||||
log.Debug().Msg("Source portal not found, re-ID is no-op")
|
||||
return ReIDResultNoOp, nil, nil
|
||||
}
|
||||
if !sourcePortal.roomCreateLock.TryLock() {
|
||||
if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
|
||||
(*cancelCreate)()
|
||||
}
|
||||
log.Debug().Msg("Waiting for source portal room creation lock")
|
||||
sourcePortal.roomCreateLock.Lock()
|
||||
log.Debug().Msg("Acquired source portal room creation lock after waiting")
|
||||
}
|
||||
sourcePortal.roomCreateLock.Lock()
|
||||
defer sourcePortal.roomCreateLock.Unlock()
|
||||
if sourcePortal.MXID == "" {
|
||||
log.Info().Msg("Source portal doesn't have Matrix room, deleting row")
|
||||
|
|
@ -78,37 +59,22 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
|
|||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Stringer("source_portal_mxid", sourcePortal.MXID)
|
||||
})
|
||||
|
||||
acquireCacheLock()
|
||||
targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true)
|
||||
if err != nil {
|
||||
br.cacheLock.Unlock()
|
||||
return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err)
|
||||
}
|
||||
if targetPortal == nil {
|
||||
log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal")
|
||||
err = sourcePortal.unlockedReID(ctx, target)
|
||||
br.cacheLock.Unlock()
|
||||
if err != nil {
|
||||
return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err)
|
||||
}
|
||||
return ReIDResultSourceReIDd, sourcePortal, nil
|
||||
}
|
||||
br.cacheLock.Unlock()
|
||||
|
||||
if !targetPortal.roomCreateLock.TryLock() {
|
||||
if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil {
|
||||
(*cancelCreate)()
|
||||
}
|
||||
log.Debug().Msg("Waiting for target portal room creation lock")
|
||||
targetPortal.roomCreateLock.Lock()
|
||||
log.Debug().Msg("Acquired target portal room creation lock after waiting")
|
||||
}
|
||||
targetPortal.roomCreateLock.Lock()
|
||||
defer targetPortal.roomCreateLock.Unlock()
|
||||
if targetPortal.MXID == "" {
|
||||
log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal")
|
||||
acquireCacheLock()
|
||||
defer br.cacheLock.Unlock()
|
||||
err = targetPortal.unlockedDelete(ctx)
|
||||
if err != nil {
|
||||
return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err)
|
||||
|
|
@ -123,9 +89,6 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
|
|||
return c.Stringer("target_portal_mxid", targetPortal.MXID)
|
||||
})
|
||||
log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal")
|
||||
sourcePortal.removeInPortalCache(ctx)
|
||||
acquireCacheLock()
|
||||
defer br.cacheLock.Unlock()
|
||||
err = sourcePortal.unlockedDelete(ctx)
|
||||
if err != nil {
|
||||
return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err)
|
||||
|
|
@ -133,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta
|
|||
go func() {
|
||||
_, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{
|
||||
Parsed: &event.TombstoneEventContent{
|
||||
Body: "This room has been merged",
|
||||
Body: fmt.Sprintf("This room has been merged"),
|
||||
ReplacementRoom: targetPortal.MXID,
|
||||
},
|
||||
}, time.Now())
|
||||
|
|
|
|||
|
|
@ -1,149 +0,0 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package provisionutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/ptr"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type RespCreateGroup struct {
|
||||
ID networkid.PortalID `json:"id"`
|
||||
MXID id.RoomID `json:"mxid"`
|
||||
Portal *bridgev2.Portal `json:"-"`
|
||||
|
||||
FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"`
|
||||
}
|
||||
|
||||
func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) {
|
||||
api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI)
|
||||
if !ok {
|
||||
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups"))
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Any("create_params", params).
|
||||
Msg("Creating group chat on remote network")
|
||||
caps := login.Bridge.Network.GetCapabilities()
|
||||
typeSpec, validType := caps.Provisioning.GroupCreation[params.Type]
|
||||
if !validType {
|
||||
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type))
|
||||
}
|
||||
if len(params.Participants) < typeSpec.Participants.MinLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength))
|
||||
} else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength))
|
||||
}
|
||||
userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
|
||||
for i, participant := range params.Participants {
|
||||
parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant))
|
||||
if ok {
|
||||
participant = parsedParticipant
|
||||
params.Participants[i] = participant
|
||||
}
|
||||
if !typeSpec.Participants.SkipIdentifierValidation {
|
||||
if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant))
|
||||
}
|
||||
}
|
||||
if api.IsThisUser(ctx, participant) {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant))
|
||||
}
|
||||
}
|
||||
if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required"))
|
||||
} else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength))
|
||||
} else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength))
|
||||
}
|
||||
if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required"))
|
||||
}
|
||||
if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required"))
|
||||
} else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength))
|
||||
} else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength))
|
||||
}
|
||||
if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required"))
|
||||
} else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer"))
|
||||
}
|
||||
if params.Username == "" && typeSpec.Username.Required {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required"))
|
||||
} else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength))
|
||||
} else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength))
|
||||
}
|
||||
if params.Parent == nil && typeSpec.Parent.Required {
|
||||
return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required"))
|
||||
}
|
||||
resp, err := api.CreateGroup(ctx, params)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to create group")
|
||||
return nil, err
|
||||
}
|
||||
if resp.PortalKey.IsEmpty() {
|
||||
return nil, ErrNoPortalKey
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Object("portal_key", resp.PortalKey).
|
||||
Msg("Successfully created group on remote network")
|
||||
if resp.Portal == nil {
|
||||
resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
|
||||
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
|
||||
}
|
||||
}
|
||||
if resp.Portal.MXID == "" {
|
||||
err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
|
||||
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
|
||||
}
|
||||
}
|
||||
for key, fp := range resp.FailedParticipants {
|
||||
if fp.InviteEventType == "" {
|
||||
fp.InviteEventType = event.EventMessage.Type
|
||||
}
|
||||
if fp.UserMXID == "" {
|
||||
ghost, err := login.Bridge.GetGhostByID(ctx, key)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant")
|
||||
} else if ghost != nil {
|
||||
fp.UserMXID = ghost.Intent.GetMXID()
|
||||
}
|
||||
}
|
||||
if fp.DMRoomMXID == "" {
|
||||
portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant")
|
||||
} else if portal != nil {
|
||||
fp.DMRoomMXID = portal.MXID
|
||||
}
|
||||
}
|
||||
}
|
||||
return &RespCreateGroup{
|
||||
ID: resp.Portal.ID,
|
||||
MXID: resp.Portal.MXID,
|
||||
Portal: resp.Portal,
|
||||
|
||||
FailedParticipants: resp.FailedParticipants,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,98 +0,0 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package provisionutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
)
|
||||
|
||||
type RespGetContactList struct {
|
||||
Contacts []*RespResolveIdentifier `json:"contacts"`
|
||||
}
|
||||
|
||||
type RespSearchUsers struct {
|
||||
Results []*RespResolveIdentifier `json:"results"`
|
||||
}
|
||||
|
||||
func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) {
|
||||
api, ok := login.Client.(bridgev2.ContactListingNetworkAPI)
|
||||
if !ok {
|
||||
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts"))
|
||||
}
|
||||
resp, err := api.GetContactList(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
|
||||
return nil, err
|
||||
}
|
||||
return &RespGetContactList{
|
||||
Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) {
|
||||
api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI)
|
||||
if !ok {
|
||||
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users"))
|
||||
}
|
||||
resp, err := api.SearchUsers(ctx, query)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list")
|
||||
return nil, err
|
||||
}
|
||||
return &RespSearchUsers{
|
||||
Results: processResolveIdentifiers(ctx, login.Bridge, resp, true),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) {
|
||||
apiResp = make([]*RespResolveIdentifier, len(resp))
|
||||
for i, contact := range resp {
|
||||
apiContact := &RespResolveIdentifier{
|
||||
ID: contact.UserID,
|
||||
}
|
||||
apiResp[i] = apiContact
|
||||
if contact.UserInfo != nil {
|
||||
if contact.UserInfo.Name != nil {
|
||||
apiContact.Name = *contact.UserInfo.Name
|
||||
}
|
||||
if contact.UserInfo.Identifiers != nil {
|
||||
apiContact.Identifiers = contact.UserInfo.Identifiers
|
||||
}
|
||||
}
|
||||
if contact.Ghost != nil {
|
||||
if syncInfo && contact.UserInfo != nil {
|
||||
contact.Ghost.UpdateInfo(ctx, contact.UserInfo)
|
||||
}
|
||||
if contact.Ghost.Name != "" {
|
||||
apiContact.Name = contact.Ghost.Name
|
||||
}
|
||||
if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) {
|
||||
apiContact.Identifiers = contact.Ghost.Identifiers
|
||||
}
|
||||
apiContact.AvatarURL = contact.Ghost.AvatarMXC
|
||||
apiContact.MXID = contact.Ghost.Intent.GetMXID()
|
||||
}
|
||||
if contact.Chat != nil {
|
||||
if contact.Chat.Portal == nil {
|
||||
var err error
|
||||
contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
|
||||
}
|
||||
}
|
||||
if contact.Chat.Portal != nil {
|
||||
apiContact.DMRoomID = contact.Chat.Portal.MXID
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package provisionutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type RespResolveIdentifier struct {
|
||||
ID networkid.UserID `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
AvatarURL id.ContentURIString `json:"avatar_url,omitempty"`
|
||||
Identifiers []string `json:"identifiers,omitempty"`
|
||||
MXID id.UserID `json:"mxid,omitempty"`
|
||||
DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"`
|
||||
|
||||
Portal *bridgev2.Portal `json:"-"`
|
||||
Ghost *bridgev2.Ghost `json:"-"`
|
||||
JustCreated bool `json:"-"`
|
||||
}
|
||||
|
||||
var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request")
|
||||
|
||||
func ResolveIdentifier(
|
||||
ctx context.Context,
|
||||
login *bridgev2.UserLogin,
|
||||
identifier string,
|
||||
createChat bool,
|
||||
) (*RespResolveIdentifier, error) {
|
||||
api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI)
|
||||
if !ok {
|
||||
return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers"))
|
||||
}
|
||||
var resp *bridgev2.ResolveIdentifierResponse
|
||||
parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier))
|
||||
validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork)
|
||||
if ok && (!vOK || validator.ValidateUserID(parsedUserID)) {
|
||||
ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID")
|
||||
return nil, err
|
||||
}
|
||||
resp = &bridgev2.ResolveIdentifierResponse{
|
||||
Ghost: ghost,
|
||||
UserID: parsedUserID,
|
||||
}
|
||||
gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI)
|
||||
if ok && createChat {
|
||||
resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat")
|
||||
return nil, err
|
||||
}
|
||||
} else if createChat || ghost.Name == "" {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Bool("create_chat", createChat).
|
||||
Bool("has_name", ghost.Name != "").
|
||||
Msg("Falling back to resolving identifier")
|
||||
resp = nil
|
||||
identifier = string(parsedUserID)
|
||||
}
|
||||
}
|
||||
if resp == nil {
|
||||
var err error
|
||||
resp, err = api.ResolveIdentifier(ctx, identifier, createChat)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier")
|
||||
return nil, err
|
||||
} else if resp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
apiResp := &RespResolveIdentifier{
|
||||
ID: resp.UserID,
|
||||
Ghost: resp.Ghost,
|
||||
}
|
||||
if resp.Ghost != nil {
|
||||
if resp.UserInfo != nil {
|
||||
resp.Ghost.UpdateInfo(ctx, resp.UserInfo)
|
||||
}
|
||||
apiResp.Name = resp.Ghost.Name
|
||||
apiResp.AvatarURL = resp.Ghost.AvatarMXC
|
||||
apiResp.Identifiers = resp.Ghost.Identifiers
|
||||
apiResp.MXID = resp.Ghost.Intent.GetMXID()
|
||||
} else if resp.UserInfo != nil && resp.UserInfo.Name != nil {
|
||||
apiResp.Name = *resp.UserInfo.Name
|
||||
}
|
||||
if resp.Chat != nil {
|
||||
if resp.Chat.PortalKey.IsEmpty() {
|
||||
return nil, ErrNoPortalKey
|
||||
}
|
||||
if resp.Chat.Portal == nil {
|
||||
var err error
|
||||
resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal")
|
||||
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal"))
|
||||
}
|
||||
}
|
||||
resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID)
|
||||
if createChat && resp.Chat.Portal.MXID == "" {
|
||||
apiResp.JustCreated = true
|
||||
err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room")
|
||||
return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room"))
|
||||
}
|
||||
}
|
||||
apiResp.Portal = resp.Chat.Portal
|
||||
apiResp.DMRoomID = resp.Chat.Portal.MXID
|
||||
}
|
||||
return apiResp, nil
|
||||
}
|
||||
|
|
@ -63,13 +63,6 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve
|
|||
return true
|
||||
}
|
||||
|
||||
var (
|
||||
ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
|
||||
ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
|
||||
ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage())
|
||||
ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage()
|
||||
)
|
||||
|
||||
func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult {
|
||||
// TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands
|
||||
|
||||
|
|
@ -85,11 +78,13 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
|
|||
return EventHandlingResultFailed
|
||||
} else if sender == nil {
|
||||
log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event")
|
||||
br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
|
||||
status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
|
||||
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
|
||||
return EventHandlingResultFailed
|
||||
} else if !sender.Permissions.SendEvents {
|
||||
if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") {
|
||||
br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt))
|
||||
status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
|
||||
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
|
||||
}
|
||||
return EventHandlingResultIgnored
|
||||
} else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") {
|
||||
|
|
@ -97,7 +92,8 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
|
|||
}
|
||||
} else if evt.Type.Class != event.EphemeralEventType {
|
||||
log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event")
|
||||
br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt))
|
||||
status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage()
|
||||
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
|
||||
return EventHandlingResultIgnored
|
||||
}
|
||||
if evt.Type == event.EventMessage && sender != nil {
|
||||
|
|
@ -106,7 +102,8 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH
|
|||
msg.RemovePerMessageProfileFallback()
|
||||
if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom {
|
||||
if !sender.Permissions.Commands {
|
||||
br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt))
|
||||
status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()
|
||||
br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt))
|
||||
return EventHandlingResultIgnored
|
||||
}
|
||||
go br.Commands.Handle(
|
||||
|
|
@ -160,27 +157,10 @@ type EventHandlingResult struct {
|
|||
Ignored bool
|
||||
Queued bool
|
||||
|
||||
SkipStateEcho bool
|
||||
|
||||
// Error is an optional reason for failure. It is not required, Success may be false even without a specific error.
|
||||
Error error
|
||||
// Whether the Error should be sent as a MSS event.
|
||||
SendMSS bool
|
||||
|
||||
// EventID from the network
|
||||
EventID id.EventID
|
||||
// Stream order from the network
|
||||
StreamOrder int64
|
||||
}
|
||||
|
||||
func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult {
|
||||
ehr.EventID = id
|
||||
return ehr
|
||||
}
|
||||
|
||||
func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult {
|
||||
ehr.StreamOrder = order
|
||||
return ehr
|
||||
}
|
||||
|
||||
func (ehr EventHandlingResult) WithError(err error) EventHandlingResult {
|
||||
|
|
@ -197,11 +177,6 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult {
|
|||
return ehr
|
||||
}
|
||||
|
||||
func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult {
|
||||
ehr.SkipStateEcho = skip
|
||||
return ehr
|
||||
}
|
||||
|
||||
func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult {
|
||||
if err == nil {
|
||||
return ehr
|
||||
|
|
@ -220,7 +195,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult {
|
|||
return ul.Bridge.QueueRemoteEvent(ul, evt)
|
||||
}
|
||||
|
||||
func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult {
|
||||
func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) {
|
||||
log := login.Log
|
||||
ctx := log.WithContext(br.BackgroundCtx)
|
||||
maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver)
|
||||
|
|
@ -236,14 +211,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandl
|
|||
if err != nil {
|
||||
log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain).
|
||||
Msg("Failed to get portal to handle remote event")
|
||||
return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err))
|
||||
return
|
||||
} else if portal == nil {
|
||||
log.Warn().
|
||||
Stringer("event_type", evt.GetType()).
|
||||
Object("portal_key", key).
|
||||
Bool("uncertain_receiver", isUncertain).
|
||||
Msg("Portal not found to handle remote event")
|
||||
return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler)
|
||||
return
|
||||
}
|
||||
// TODO put this in a better place, and maybe cache to avoid constant db queries
|
||||
login.MarkInPortal(ctx, portal)
|
||||
|
|
|
|||
|
|
@ -65,19 +65,14 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal)
|
|||
type ChatDelete struct {
|
||||
EventMeta
|
||||
OnlyForMe bool
|
||||
Children bool
|
||||
}
|
||||
|
||||
var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil)
|
||||
var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil)
|
||||
|
||||
func (evt *ChatDelete) DeleteOnlyForMe() bool {
|
||||
return evt.OnlyForMe
|
||||
}
|
||||
|
||||
func (evt *ChatDelete) DeleteChildren() bool {
|
||||
return evt.Children
|
||||
}
|
||||
|
||||
// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange].
|
||||
type ChatInfoChange struct {
|
||||
EventMeta
|
||||
|
|
|
|||
|
|
@ -59,41 +59,6 @@ func (evt *Message[T]) GetTransactionID() networkid.TransactionID {
|
|||
return evt.TransactionID
|
||||
}
|
||||
|
||||
// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data.
|
||||
type PreConvertedMessage struct {
|
||||
EventMeta
|
||||
Data *bridgev2.ConvertedMessage
|
||||
ID networkid.MessageID
|
||||
TransactionID networkid.TransactionID
|
||||
|
||||
HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error)
|
||||
}
|
||||
|
||||
var (
|
||||
_ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil)
|
||||
_ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil)
|
||||
_ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil)
|
||||
)
|
||||
|
||||
func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) {
|
||||
return evt.Data, nil
|
||||
}
|
||||
|
||||
func (evt *PreConvertedMessage) GetID() networkid.MessageID {
|
||||
return evt.ID
|
||||
}
|
||||
|
||||
func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID {
|
||||
return evt.TransactionID
|
||||
}
|
||||
|
||||
func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) {
|
||||
if evt.HandleExistingFunc == nil {
|
||||
return bridgev2.UpsertResult{}, nil
|
||||
}
|
||||
return evt.HandleExistingFunc(ctx, portal, intent, existing)
|
||||
}
|
||||
|
||||
type MessageRemove struct {
|
||||
EventMeta
|
||||
|
||||
|
|
|
|||
|
|
@ -101,18 +101,6 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E
|
|||
return evt
|
||||
}
|
||||
|
||||
func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta {
|
||||
origFunc := evt.LogContext
|
||||
if origFunc == nil {
|
||||
evt.LogContext = f
|
||||
return evt
|
||||
}
|
||||
evt.LogContext = func(c zerolog.Context) zerolog.Context {
|
||||
return f(origFunc(c))
|
||||
}
|
||||
return evt
|
||||
}
|
||||
|
||||
func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta {
|
||||
evt.PortalKey = p
|
||||
return evt
|
||||
|
|
|
|||
|
|
@ -164,17 +164,14 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) {
|
|||
ul.UserMXID: 50,
|
||||
},
|
||||
},
|
||||
Invite: []id.UserID{ul.UserMXID},
|
||||
RoomVersion: id.RoomV11,
|
||||
Invite: []id.UserID{ul.UserMXID},
|
||||
}
|
||||
if autoJoin {
|
||||
req.BeeperInitialMembers = []id.UserID{ul.UserMXID}
|
||||
// TODO remove this after initial_members is supported in hungryserv
|
||||
req.BeeperAutoJoinInvites = true
|
||||
}
|
||||
pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI)
|
||||
if ok {
|
||||
pfc.CustomizePersonalFilteringSpace(req)
|
||||
}
|
||||
ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create space room: %w", err)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,9 @@ import (
|
|||
|
||||
"github.com/tidwall/sjson"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"go.mau.fi/util/ptr"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/bridgev2/networkid"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
|
|
@ -88,8 +87,6 @@ type RemoteProfile struct {
|
|||
Username string `json:"username,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Avatar id.ContentURIString `json:"avatar,omitempty"`
|
||||
|
||||
AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"`
|
||||
}
|
||||
|
||||
func coalesce[T ~string](a, b T) T {
|
||||
|
|
@ -105,14 +102,11 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile {
|
|||
other.Username = coalesce(rp.Username, other.Username)
|
||||
other.Name = coalesce(rp.Name, other.Name)
|
||||
other.Avatar = coalesce(rp.Avatar, other.Avatar)
|
||||
if rp.AvatarFile != nil {
|
||||
other.AvatarFile = rp.AvatarFile
|
||||
}
|
||||
return other
|
||||
}
|
||||
|
||||
func (rp *RemoteProfile) IsZero() bool {
|
||||
return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil)
|
||||
func (rp *RemoteProfile) IsEmpty() bool {
|
||||
return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "")
|
||||
}
|
||||
|
||||
type BridgeState struct {
|
||||
|
|
@ -126,10 +120,10 @@ type BridgeState struct {
|
|||
|
||||
UserAction BridgeStateUserAction `json:"user_action,omitempty"`
|
||||
|
||||
UserID id.UserID `json:"user_id,omitempty"`
|
||||
RemoteID networkid.UserLoginID `json:"remote_id,omitempty"`
|
||||
RemoteName string `json:"remote_name,omitempty"`
|
||||
RemoteProfile RemoteProfile `json:"remote_profile,omitzero"`
|
||||
UserID id.UserID `json:"user_id,omitempty"`
|
||||
RemoteID string `json:"remote_id,omitempty"`
|
||||
RemoteName string `json:"remote_name,omitempty"`
|
||||
RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"`
|
||||
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Info map[string]interface{} `json:"info,omitempty"`
|
||||
|
|
@ -209,7 +203,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool {
|
|||
pong.StateEvent == newPong.StateEvent &&
|
||||
pong.RemoteName == newPong.RemoteName &&
|
||||
pong.UserAction == newPong.UserAction &&
|
||||
pong.RemoteProfile == newPong.RemoteProfile &&
|
||||
ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) &&
|
||||
pong.Error == newPong.Error &&
|
||||
maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) &&
|
||||
pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now())
|
||||
|
|
|
|||
|
|
@ -176,10 +176,6 @@ func (user *User) GetUserLogins() []*UserLogin {
|
|||
return maps.Values(user.logins)
|
||||
}
|
||||
|
||||
func (user *User) HasTooManyLogins() bool {
|
||||
return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins
|
||||
}
|
||||
|
||||
func (user *User) GetFormattedUserLogins() string {
|
||||
user.Bridge.cacheLock.Lock()
|
||||
logins := make([]string, len(user.logins))
|
||||
|
|
@ -229,8 +225,9 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) {
|
|||
user.MXID: 50,
|
||||
},
|
||||
},
|
||||
Invite: []id.UserID{user.MXID},
|
||||
IsDirect: true,
|
||||
RoomVersion: id.RoomV11,
|
||||
Invite: []id.UserID{user.MXID},
|
||||
IsDirect: true,
|
||||
}
|
||||
if autoJoin {
|
||||
req.BeeperInitialMembers = []id.UserID{user.MXID}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -51,8 +50,6 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
// TODO if loading the user caused the provided userlogin to be loaded, cancel here?
|
||||
// Currently this will double-load it
|
||||
}
|
||||
userLogin := &UserLogin{
|
||||
UserLogin: dbUserLogin,
|
||||
|
|
@ -143,12 +140,6 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin {
|
|||
return br.userLoginsByID[id]
|
||||
}
|
||||
|
||||
func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) {
|
||||
br.cacheLock.Lock()
|
||||
defer br.cacheLock.Unlock()
|
||||
return slices.Collect(maps.Values(br.userLoginsByID))
|
||||
}
|
||||
|
||||
func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) {
|
||||
br.cacheLock.Lock()
|
||||
defer br.cacheLock.Unlock()
|
||||
|
|
@ -510,9 +501,9 @@ var _ status.BridgeStateFiller = (*UserLogin)(nil)
|
|||
|
||||
func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState {
|
||||
state.UserID = ul.UserMXID
|
||||
state.RemoteID = ul.ID
|
||||
state.RemoteID = string(ul.ID)
|
||||
state.RemoteName = ul.RemoteName
|
||||
state.RemoteProfile = ul.RemoteProfile
|
||||
state.RemoteProfile = &ul.RemoteProfile
|
||||
filler, ok := ul.Client.(status.BridgeStateFiller)
|
||||
if ok {
|
||||
return filler.FillBridgeState(state)
|
||||
|
|
|
|||
521
client.go
521
client.go
|
|
@ -111,8 +111,6 @@ type Client struct {
|
|||
// Set to true to disable automatically sleeping on 429 errors.
|
||||
IgnoreRateLimit bool
|
||||
|
||||
ResponseSizeLimit int64
|
||||
|
||||
txnID int32
|
||||
|
||||
// Should the ?user_id= query parameter be set in requests?
|
||||
|
|
@ -145,8 +143,6 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown
|
|||
return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName)
|
||||
}
|
||||
|
||||
const WellKnownMaxSize = 64 * 1024
|
||||
|
||||
func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) {
|
||||
wellKnownURL := url.URL{
|
||||
Scheme: "https",
|
||||
|
|
@ -172,15 +168,11 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve
|
|||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
} else if resp.ContentLength > WellKnownMaxSize {
|
||||
return nil, errors.New(".well-known response too large")
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize))
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(data) >= WellKnownMaxSize {
|
||||
return nil, errors.New(".well-known response too large")
|
||||
}
|
||||
|
||||
var wellKnown ClientWellKnown
|
||||
|
|
@ -331,7 +323,6 @@ const (
|
|||
LogBodyContextKey contextKey = iota
|
||||
LogRequestIDContextKey
|
||||
MaxAttemptsContextKey
|
||||
SyncTokenContextKey
|
||||
)
|
||||
|
||||
func (cli *Client) RequestStart(req *http.Request) {
|
||||
|
|
@ -386,14 +377,7 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er
|
|||
}
|
||||
}
|
||||
if body := req.Context().Value(LogBodyContextKey); body != nil {
|
||||
switch typedLogBody := body.(type) {
|
||||
case json.RawMessage:
|
||||
evt.RawJSON("req_body", typedLogBody)
|
||||
case string:
|
||||
evt.Str("req_body", typedLogBody)
|
||||
default:
|
||||
panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body))
|
||||
}
|
||||
evt.Interface("req_body", body)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
evt.Msg("Request canceled")
|
||||
|
|
@ -410,43 +394,32 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin
|
|||
return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody})
|
||||
}
|
||||
|
||||
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error)
|
||||
type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error)
|
||||
|
||||
type FullRequest struct {
|
||||
Method string
|
||||
URL string
|
||||
Headers http.Header
|
||||
RequestJSON interface{}
|
||||
RequestBytes []byte
|
||||
RequestBody io.Reader
|
||||
RequestLength int64
|
||||
ResponseJSON interface{}
|
||||
MaxAttempts int
|
||||
BackoffDuration time.Duration
|
||||
SensitiveContent bool
|
||||
Handler ClientResponseHandler
|
||||
DontReadResponse bool
|
||||
ResponseSizeLimit int64
|
||||
Logger *zerolog.Logger
|
||||
Client *http.Client
|
||||
Method string
|
||||
URL string
|
||||
Headers http.Header
|
||||
RequestJSON interface{}
|
||||
RequestBytes []byte
|
||||
RequestBody io.Reader
|
||||
RequestLength int64
|
||||
ResponseJSON interface{}
|
||||
MaxAttempts int
|
||||
BackoffDuration time.Duration
|
||||
SensitiveContent bool
|
||||
Handler ClientResponseHandler
|
||||
DontReadResponse bool
|
||||
Logger *zerolog.Logger
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
var requestID int32
|
||||
var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes"
|
||||
|
||||
func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) {
|
||||
reqID := atomic.AddInt32(&requestID, 1)
|
||||
logger := zerolog.Ctx(ctx)
|
||||
if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
|
||||
logger = params.Logger
|
||||
}
|
||||
ctx = logger.With().
|
||||
Int32("req_id", reqID).
|
||||
Logger().WithContext(ctx)
|
||||
|
||||
var logBody any
|
||||
var reqBody io.Reader
|
||||
var reqLen int64
|
||||
reqBody := params.RequestBody
|
||||
if params.RequestJSON != nil {
|
||||
jsonStr, err := json.Marshal(params.RequestJSON)
|
||||
if err != nil {
|
||||
|
|
@ -457,38 +430,33 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
|
|||
}
|
||||
if params.SensitiveContent && !logSensitiveContent {
|
||||
logBody = "<sensitive content omitted>"
|
||||
} else if len(jsonStr) > 32768 {
|
||||
logBody = fmt.Sprintf("<large content omitted (%d bytes)>", len(jsonStr))
|
||||
} else {
|
||||
logBody = json.RawMessage(jsonStr)
|
||||
logBody = params.RequestJSON
|
||||
}
|
||||
reqBody = bytes.NewReader(jsonStr)
|
||||
reqLen = int64(len(jsonStr))
|
||||
} else if params.RequestBytes != nil {
|
||||
logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes))
|
||||
reqBody = bytes.NewReader(params.RequestBytes)
|
||||
reqLen = int64(len(params.RequestBytes))
|
||||
} else if params.RequestBody != nil {
|
||||
logBody = "<unknown stream of bytes>"
|
||||
reqLen = -1
|
||||
if params.RequestLength > 0 {
|
||||
logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
|
||||
reqLen = params.RequestLength
|
||||
} else if params.RequestLength == 0 {
|
||||
zerolog.Ctx(ctx).Warn().
|
||||
Msg("RequestBody passed without specifying request length")
|
||||
}
|
||||
reqBody = params.RequestBody
|
||||
params.RequestLength = int64(len(params.RequestBytes))
|
||||
} else if params.RequestLength > 0 && params.RequestBody != nil {
|
||||
logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
|
||||
if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok {
|
||||
// Prevent HTTP from closing the request body, it might be needed for retries
|
||||
reqBody = nopCloseSeeker{rsc}
|
||||
}
|
||||
} else if params.Method != http.MethodGet && params.Method != http.MethodHead {
|
||||
params.RequestJSON = struct{}{}
|
||||
logBody = json.RawMessage("{}")
|
||||
logBody = params.RequestJSON
|
||||
reqBody = bytes.NewReader([]byte("{}"))
|
||||
reqLen = 2
|
||||
}
|
||||
reqID := atomic.AddInt32(&requestID, 1)
|
||||
logger := zerolog.Ctx(ctx)
|
||||
if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
|
||||
logger = params.Logger
|
||||
}
|
||||
ctx = logger.With().
|
||||
Int32("req_id", reqID).
|
||||
Logger().WithContext(ctx)
|
||||
ctx = context.WithValue(ctx, LogBodyContextKey, logBody)
|
||||
ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID))
|
||||
req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody)
|
||||
|
|
@ -504,7 +472,9 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e
|
|||
if params.RequestJSON != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
req.ContentLength = reqLen
|
||||
if params.RequestLength > 0 && params.RequestBody != nil {
|
||||
req.ContentLength = params.RequestLength
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
|
|
@ -555,25 +525,10 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque
|
|||
if len(cli.AccessToken) > 0 {
|
||||
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
|
||||
}
|
||||
if params.ResponseSizeLimit == 0 {
|
||||
params.ResponseSizeLimit = cli.ResponseSizeLimit
|
||||
}
|
||||
if params.ResponseSizeLimit == 0 {
|
||||
params.ResponseSizeLimit = DefaultResponseSizeLimit
|
||||
}
|
||||
if params.Client == nil {
|
||||
params.Client = cli.Client
|
||||
}
|
||||
return cli.executeCompiledRequest(
|
||||
req,
|
||||
params.MaxAttempts-1,
|
||||
params.BackoffDuration,
|
||||
params.ResponseJSON,
|
||||
params.Handler,
|
||||
params.DontReadResponse,
|
||||
params.ResponseSizeLimit,
|
||||
params.Client,
|
||||
)
|
||||
return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client)
|
||||
}
|
||||
|
||||
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
|
||||
|
|
@ -584,17 +539,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
|
|||
return log
|
||||
}
|
||||
|
||||
func (cli *Client) doRetry(
|
||||
req *http.Request,
|
||||
cause error,
|
||||
retries int,
|
||||
backoff time.Duration,
|
||||
responseJSON any,
|
||||
handler ClientResponseHandler,
|
||||
dontReadResponse bool,
|
||||
sizeLimit int64,
|
||||
client *http.Client,
|
||||
) ([]byte, *http.Response, error) {
|
||||
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
|
||||
log := zerolog.Ctx(req.Context())
|
||||
if req.Body != nil {
|
||||
var err error
|
||||
|
|
@ -623,30 +568,16 @@ func (cli *Client) doRetry(
|
|||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-req.Context().Done():
|
||||
if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) {
|
||||
return nil, nil, req.Context().Err()
|
||||
}
|
||||
return nil, nil, req.Context().Err()
|
||||
}
|
||||
if cli.UpdateRequestOnRetry != nil {
|
||||
req = cli.UpdateRequestOnRetry(req, cause)
|
||||
}
|
||||
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client)
|
||||
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client)
|
||||
}
|
||||
|
||||
func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) {
|
||||
if res.ContentLength > limit {
|
||||
return nil, HTTPError{
|
||||
Request: req,
|
||||
Response: res,
|
||||
|
||||
Message: "not reading response",
|
||||
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
|
||||
}
|
||||
}
|
||||
contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1))
|
||||
if err == nil && len(contents) > int(limit) {
|
||||
err = ErrBodyReadReachedLimit
|
||||
}
|
||||
func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) {
|
||||
contents, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, HTTPError{
|
||||
Request: req,
|
||||
|
|
@ -667,20 +598,17 @@ func closeTemp(log *zerolog.Logger, file *os.File) {
|
|||
}
|
||||
}
|
||||
|
||||
func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
log := zerolog.Ctx(req.Context())
|
||||
file, err := os.CreateTemp("", "mautrix-response-")
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
|
||||
_, err = handleNormalResponse(req, res, responseJSON, limit)
|
||||
_, err = handleNormalResponse(req, res, responseJSON)
|
||||
return nil, err
|
||||
}
|
||||
defer closeTemp(log, file)
|
||||
var n int64
|
||||
if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil {
|
||||
if _, err = io.Copy(file, res.Body); err != nil {
|
||||
return nil, fmt.Errorf("failed to copy response to file: %w", err)
|
||||
} else if n > limit {
|
||||
return nil, ErrBodyReadReachedLimit
|
||||
} else if _, err = file.Seek(0, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err)
|
||||
} else if err = json.NewDecoder(file).Decode(responseJSON); err != nil {
|
||||
|
|
@ -690,12 +618,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON any, lim
|
|||
}
|
||||
}
|
||||
|
||||
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
if contents, err := readResponseBody(req, res, limit); err != nil {
|
||||
func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
if contents, err := readResponseBody(req, res); err != nil {
|
||||
return nil, err
|
||||
} else if responseJSON == nil {
|
||||
return contents, nil
|
||||
|
|
@ -713,13 +641,8 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON an
|
|||
}
|
||||
}
|
||||
|
||||
const ErrorResponseSizeLimit = 512 * 1024
|
||||
|
||||
var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024
|
||||
|
||||
func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
|
||||
defer res.Body.Close()
|
||||
contents, err := readResponseBody(req, res, ErrorResponseSizeLimit)
|
||||
contents, err := readResponseBody(req, res)
|
||||
if err != nil {
|
||||
return contents, err
|
||||
}
|
||||
|
|
@ -738,31 +661,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (cli *Client) executeCompiledRequest(
|
||||
req *http.Request,
|
||||
retries int,
|
||||
backoff time.Duration,
|
||||
responseJSON any,
|
||||
handler ClientResponseHandler,
|
||||
dontReadResponse bool,
|
||||
sizeLimit int64,
|
||||
client *http.Client,
|
||||
) ([]byte, *http.Response, error) {
|
||||
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) {
|
||||
cli.RequestStart(req)
|
||||
startTime := time.Now()
|
||||
res, err := client.Do(req)
|
||||
duration := time.Since(startTime)
|
||||
duration := time.Now().Sub(startTime)
|
||||
if res != nil && !dontReadResponse {
|
||||
defer res.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
// Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry
|
||||
canRetry := !errors.Is(err, context.Canceled) ||
|
||||
errors.Is(context.Cause(req.Context()), ErrContextCancelRetry)
|
||||
if retries > 0 && canRetry {
|
||||
return cli.doRetry(
|
||||
req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
|
||||
)
|
||||
if retries > 0 && !errors.Is(err, context.Canceled) {
|
||||
return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client)
|
||||
}
|
||||
err = HTTPError{
|
||||
Request: req,
|
||||
|
|
@ -777,9 +686,7 @@ func (cli *Client) executeCompiledRequest(
|
|||
|
||||
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
|
||||
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
|
||||
return cli.doRetry(
|
||||
req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client,
|
||||
)
|
||||
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client)
|
||||
}
|
||||
|
||||
var body []byte
|
||||
|
|
@ -787,7 +694,7 @@ func (cli *Client) executeCompiledRequest(
|
|||
body, err = ParseErrorResponse(req, res)
|
||||
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
|
||||
} else {
|
||||
body, err = handler(req, res, responseJSON, sizeLimit)
|
||||
body, err = handler(req, res, responseJSON)
|
||||
cli.LogRequestDone(req, res, nil, err, len(body), duration)
|
||||
}
|
||||
return body, res, err
|
||||
|
|
@ -847,7 +754,7 @@ func (req *ReqSync) BuildQuery() map[string]string {
|
|||
query["full_state"] = "true"
|
||||
}
|
||||
if req.UseStateAfter {
|
||||
query["use_state_after"] = "true"
|
||||
query["org.matrix.msc4222.use_state_after"] = "true"
|
||||
}
|
||||
if req.BeeperStreaming {
|
||||
query["com.beeper.streaming"] = "true"
|
||||
|
|
@ -871,7 +778,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp
|
|||
}
|
||||
start := time.Now()
|
||||
_, err = cli.MakeFullRequest(ctx, fullReq)
|
||||
duration := time.Since(start)
|
||||
duration := time.Now().Sub(start)
|
||||
timeout := time.Duration(req.Timeout) * time.Millisecond
|
||||
buffer := 10 * time.Second
|
||||
if req.Since == "" {
|
||||
|
|
@ -918,7 +825,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp
|
|||
return
|
||||
}
|
||||
|
||||
func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
|
||||
func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) {
|
||||
var bodyBytes []byte
|
||||
bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{
|
||||
Method: http.MethodPost,
|
||||
|
|
@ -942,7 +849,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[an
|
|||
// Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
|
||||
//
|
||||
// Registers with kind=user. For kind=guest, see RegisterGuest.
|
||||
func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
|
||||
func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
|
||||
u := cli.BuildClientURL("v3", "register")
|
||||
return cli.register(ctx, u, req)
|
||||
}
|
||||
|
|
@ -951,7 +858,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRe
|
|||
// with kind=guest.
|
||||
//
|
||||
// For kind=user, see Register.
|
||||
func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) {
|
||||
func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) {
|
||||
query := map[string]string{
|
||||
"kind": "guest",
|
||||
}
|
||||
|
|
@ -974,8 +881,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*R
|
|||
// panic(err)
|
||||
// }
|
||||
// token := res.AccessToken
|
||||
func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) {
|
||||
_, uia, err := cli.Register(ctx, req)
|
||||
func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) {
|
||||
res, uia, err := cli.Register(ctx, req)
|
||||
if err != nil && uia == nil {
|
||||
return nil, err
|
||||
} else if uia == nil {
|
||||
|
|
@ -984,7 +891,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*R
|
|||
return nil, errors.New("server does not support m.login.dummy")
|
||||
}
|
||||
req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session}
|
||||
res, _, err := cli.Register(ctx, req)
|
||||
res, _, err = cli.Register(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1148,19 +1055,8 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs
|
|||
return
|
||||
}
|
||||
|
||||
func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "user_directory", "search")
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{
|
||||
SearchTerm: query,
|
||||
Limit: limit,
|
||||
}, &resp)
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) {
|
||||
supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms)
|
||||
supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms)
|
||||
if cli.SpecVersions != nil && !supportsUnstable && !supportsStable {
|
||||
if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) {
|
||||
err = fmt.Errorf("server does not support fetching mutual rooms")
|
||||
return
|
||||
}
|
||||
|
|
@ -1170,10 +1066,7 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex
|
|||
if len(extras) > 0 {
|
||||
query["from"] = extras[0].From
|
||||
}
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query)
|
||||
if !supportsStable && supportsUnstable {
|
||||
urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
|
||||
}
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
|
||||
return
|
||||
}
|
||||
|
|
@ -1195,7 +1088,8 @@ func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via
|
|||
|
||||
// GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname
|
||||
func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) {
|
||||
err = cli.GetProfileField(ctx, mxid, "displayname", &resp)
|
||||
urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname")
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -1206,47 +1100,41 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay
|
|||
|
||||
// SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname
|
||||
func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) {
|
||||
return cli.SetProfileField(ctx, "displayname", displayName)
|
||||
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname")
|
||||
s := struct {
|
||||
DisplayName string `json:"displayname"`
|
||||
}{displayName}
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
|
||||
func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
|
||||
if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
|
||||
urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
|
||||
}
|
||||
// UnstableSetProfileField sets an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
|
||||
func (cli *Client) UnstableSetProfileField(ctx context.Context, key string, value any) (err error) {
|
||||
urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{
|
||||
key: value,
|
||||
}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname
|
||||
func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key)
|
||||
if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
|
||||
urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
|
||||
}
|
||||
// UnstableDeleteProfileField deletes an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133
|
||||
func (cli *Client) UnstableDeleteProfileField(ctx context.Context, key string) (err error) {
|
||||
urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname
|
||||
func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "profile", userID, key)
|
||||
if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) {
|
||||
urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key)
|
||||
}
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into)
|
||||
return
|
||||
}
|
||||
|
||||
// GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url
|
||||
func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url")
|
||||
s := struct {
|
||||
AvatarURL id.ContentURI `json:"avatar_url"`
|
||||
}{}
|
||||
err = cli.GetProfileField(ctx, mxid, "avatar_url", &s)
|
||||
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
url = s.AvatarURL
|
||||
return
|
||||
}
|
||||
|
|
@ -1338,9 +1226,6 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
|
|||
if req.UnstableDelay > 0 {
|
||||
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
|
||||
}
|
||||
if req.UnstableStickyDuration > 0 {
|
||||
queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
|
||||
}
|
||||
|
||||
if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted {
|
||||
var isEncrypted bool
|
||||
|
|
@ -1364,51 +1249,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
|
|||
return
|
||||
}
|
||||
|
||||
// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint.
|
||||
// contentJSON should be a value that can be encoded as JSON using json.Marshal.
|
||||
func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
|
||||
var req ReqSendEvent
|
||||
if len(extra) > 0 {
|
||||
req = extra[0]
|
||||
}
|
||||
|
||||
var txnID string
|
||||
if len(req.TransactionID) > 0 {
|
||||
txnID = req.TransactionID
|
||||
} else {
|
||||
txnID = cli.TxnID()
|
||||
}
|
||||
|
||||
queryParams := map[string]string{}
|
||||
if req.Timestamp > 0 {
|
||||
queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
|
||||
}
|
||||
|
||||
if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted {
|
||||
var isEncrypted bool
|
||||
isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to check if room is encrypted: %w", err)
|
||||
return
|
||||
}
|
||||
if isEncrypted {
|
||||
if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil {
|
||||
err = fmt.Errorf("failed to encrypt event: %w", err)
|
||||
return
|
||||
}
|
||||
eventType = event.EventEncrypted
|
||||
}
|
||||
}
|
||||
|
||||
urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID}
|
||||
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
|
||||
return
|
||||
}
|
||||
|
||||
// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
|
||||
// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
|
||||
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
|
||||
func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
|
||||
func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) {
|
||||
var req ReqSendEvent
|
||||
if len(extra) > 0 {
|
||||
req = extra[0]
|
||||
|
|
@ -1418,18 +1261,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
|
|||
if req.MeowEventID != "" {
|
||||
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
|
||||
}
|
||||
if req.TransactionID != "" {
|
||||
queryParams["fi.mau.transaction_id"] = req.TransactionID
|
||||
}
|
||||
if req.UnstableDelay > 0 {
|
||||
queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10)
|
||||
}
|
||||
if req.UnstableStickyDuration > 0 {
|
||||
queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10)
|
||||
}
|
||||
if req.Timestamp > 0 {
|
||||
queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10)
|
||||
}
|
||||
|
||||
urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}
|
||||
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
|
||||
|
|
@ -1442,38 +1276,14 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
|
|||
|
||||
// SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey
|
||||
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
|
||||
//
|
||||
// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead.
|
||||
func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
|
||||
resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{
|
||||
Timestamp: ts,
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
|
||||
"ts": strconv.FormatInt(ts, 10),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) {
|
||||
query := map[string]string{}
|
||||
if req.DelayID != "" {
|
||||
query["delay_id"] = string(req.DelayID)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
|
||||
}
|
||||
if req.Status != "" {
|
||||
query["status"] = string(req.Status)
|
||||
}
|
||||
if req.NextBatch != "" {
|
||||
query["next_batch"] = req.NextBatch
|
||||
}
|
||||
|
||||
urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp)
|
||||
|
||||
// Migration: merge old keys with new ones
|
||||
if resp != nil {
|
||||
resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...)
|
||||
resp.DelayedEvents = nil
|
||||
resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...)
|
||||
resp.FinalisedEvents = nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -1766,20 +1576,11 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy
|
|||
}
|
||||
|
||||
// parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map.
|
||||
func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) {
|
||||
if res.ContentLength > limit {
|
||||
return nil, HTTPError{
|
||||
Request: req,
|
||||
Response: res,
|
||||
|
||||
Message: "not reading response",
|
||||
WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024),
|
||||
}
|
||||
}
|
||||
func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
response := make(RoomStateMap)
|
||||
responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event)
|
||||
*responsePtr = response
|
||||
dec := json.NewDecoder(io.LimitReader(res.Body, limit))
|
||||
dec := json.NewDecoder(res.Body)
|
||||
|
||||
arrayStart, err := dec.Token()
|
||||
if err != nil {
|
||||
|
|
@ -1813,8 +1614,6 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
type RoomStateMap = map[event.Type]map[string]*event.Event
|
||||
|
||||
// State gets all state in a room.
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate
|
||||
func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) {
|
||||
|
|
@ -1897,9 +1696,6 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa
|
|||
}
|
||||
|
||||
func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) {
|
||||
if mxcURL.IsEmpty() {
|
||||
return nil, fmt.Errorf("empty mxc uri provided to Download")
|
||||
}
|
||||
_, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
|
||||
Method: http.MethodGet,
|
||||
URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID),
|
||||
|
|
@ -1908,41 +1704,6 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re
|
|||
return resp, err
|
||||
}
|
||||
|
||||
type DownloadThumbnailExtra struct {
|
||||
Method string
|
||||
Animated bool
|
||||
}
|
||||
|
||||
func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) {
|
||||
if mxcURL.IsEmpty() {
|
||||
return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail")
|
||||
}
|
||||
if len(extras) > 1 {
|
||||
panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras)))
|
||||
}
|
||||
var extra DownloadThumbnailExtra
|
||||
if len(extras) == 1 {
|
||||
extra = extras[0]
|
||||
}
|
||||
path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID}
|
||||
query := map[string]string{
|
||||
"height": strconv.Itoa(height),
|
||||
"width": strconv.Itoa(width),
|
||||
}
|
||||
if extra.Method != "" {
|
||||
query["method"] = extra.Method
|
||||
}
|
||||
if extra.Animated {
|
||||
query["animated"] = "true"
|
||||
}
|
||||
_, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{
|
||||
Method: http.MethodGet,
|
||||
URL: cli.BuildURLWithQuery(path, query),
|
||||
DontReadResponse: true,
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) {
|
||||
resp, err := cli.Download(ctx, mxcURL)
|
||||
if err != nil {
|
||||
|
|
@ -1989,15 +1750,10 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr
|
|||
}
|
||||
req.MXC = resp.ContentURI
|
||||
req.UnstableUploadURL = resp.UnstableUploadURL
|
||||
if req.AsyncContext == nil {
|
||||
req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background())
|
||||
}
|
||||
go func() {
|
||||
_, err = cli.UploadMedia(req.AsyncContext, req)
|
||||
_, err = cli.UploadMedia(ctx, req)
|
||||
if err != nil {
|
||||
zerolog.Ctx(req.AsyncContext).Err(err).
|
||||
Stringer("mxc", req.MXC).
|
||||
Msg("Async upload of media failed")
|
||||
cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed")
|
||||
}
|
||||
}()
|
||||
return resp, nil
|
||||
|
|
@ -2033,7 +1789,6 @@ type ReqUploadMedia struct {
|
|||
ContentType string
|
||||
FileName string
|
||||
|
||||
AsyncContext context.Context
|
||||
DoneCallback func()
|
||||
|
||||
// MXC specifies an existing MXC URI which doesn't have content yet to upload into.
|
||||
|
|
@ -2046,10 +1801,7 @@ type ReqUploadMedia struct {
|
|||
}
|
||||
|
||||
func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) {
|
||||
cli.Log.Debug().
|
||||
Str("url", url).
|
||||
Int64("content_length", contentLength).
|
||||
Msg("Uploading media to external URL")
|
||||
cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -2098,16 +1850,8 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*
|
|||
Msg("Error uploading media to external URL, not retrying")
|
||||
return nil, err
|
||||
}
|
||||
backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries)
|
||||
cli.Log.Warn().Err(err).
|
||||
Str("url", data.UnstableUploadURL).
|
||||
Int("retry_in_seconds", int(backoff.Seconds())).
|
||||
cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err).
|
||||
Msg("Error uploading media to external URL, retrying")
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
retries--
|
||||
_, err = readerSeeker.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
|
|
@ -2687,15 +2431,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req
|
|||
return err
|
||||
}
|
||||
|
||||
func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error {
|
||||
func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error {
|
||||
urlPath := cli.BuildClientURL("v3", "devices", deviceID)
|
||||
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error {
|
||||
func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error {
|
||||
urlPath := cli.BuildClientURL("v3", "delete_devices")
|
||||
_, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil)
|
||||
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -2704,7 +2448,7 @@ type UIACallback = func(*RespUserInteractive) interface{}
|
|||
// UploadCrossSigningKeys uploads the given cross-signing keys to the server.
|
||||
// Because the endpoint requires user-interactive authentication a callback must be provided that,
|
||||
// given the UI auth parameters, produces the required result (or nil to end the flow).
|
||||
func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error {
|
||||
func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error {
|
||||
content, err := cli.MakeFullRequest(ctx, FullRequest{
|
||||
Method: http.MethodPost,
|
||||
URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"),
|
||||
|
|
@ -2786,61 +2530,24 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri
|
|||
return err
|
||||
}
|
||||
|
||||
// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin.
|
||||
// BatchSend sends a batch of historical events into a room. This is only available for appservices.
|
||||
//
|
||||
// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid
|
||||
func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "admin", "whois", userID)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *Client) makeMSC4323URL(action string, target id.UserID) string {
|
||||
if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) {
|
||||
return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target)
|
||||
} else if cli.SpecVersions.Supports(FeatureStableAccountModeration) {
|
||||
return cli.BuildClientURL("v1", "admin", action, target)
|
||||
// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead.
|
||||
func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) {
|
||||
path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"}
|
||||
query := map[string]string{
|
||||
"prev_event_id": req.PrevEventID.String(),
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSuspendedStatus uses MSC4323 to check if a user is suspended.
|
||||
func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) {
|
||||
urlPath := cli.makeMSC4323URL("suspend", userID)
|
||||
if urlPath == "" {
|
||||
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
|
||||
if req.BeeperNewMessages {
|
||||
query["com.beeper.new_messages"] = "true"
|
||||
}
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
|
||||
return
|
||||
}
|
||||
|
||||
// GetLockStatus uses MSC4323 to check if a user is locked.
|
||||
func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) {
|
||||
urlPath := cli.makeMSC4323URL("lock", userID)
|
||||
if urlPath == "" {
|
||||
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
|
||||
if req.BeeperMarkReadBy != "" {
|
||||
query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String()
|
||||
}
|
||||
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res)
|
||||
return
|
||||
}
|
||||
|
||||
// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended.
|
||||
func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) {
|
||||
urlPath := cli.makeMSC4323URL("suspend", userID)
|
||||
if urlPath == "" {
|
||||
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
|
||||
if len(req.BatchID) > 0 {
|
||||
query["batch_id"] = req.BatchID.String()
|
||||
}
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res)
|
||||
return
|
||||
}
|
||||
|
||||
// SetLockStatus uses MSC4323 to set whether a user account is locked.
|
||||
func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) {
|
||||
urlPath := cli.makeMSC4323URL("lock", userID)
|
||||
if urlPath == "" {
|
||||
return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support")
|
||||
}
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res)
|
||||
_, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,158 +0,0 @@
|
|||
// Copyright (c) 2026 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mautrix_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) {
|
||||
roomID := id.RoomID("!room:example.com")
|
||||
evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
|
||||
txnID := "txn-123"
|
||||
|
||||
var gotPath string
|
||||
var gotQueryTS string
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotQueryTS = r.URL.Query().Get("ts")
|
||||
assert.Equal(t, http.MethodPut, r.Method)
|
||||
_, _ = w.Write([]byte(`{"event_id":"$evt"}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
cli, err := mautrix.NewClient(ts.URL, "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cli.BeeperSendEphemeralEvent(
|
||||
context.Background(),
|
||||
roomID,
|
||||
evtType,
|
||||
map[string]any{"foo": "bar"},
|
||||
mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/"))
|
||||
assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID))
|
||||
assert.Equal(t, "1234", gotQueryTS)
|
||||
}
|
||||
|
||||
func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
cli, err := mautrix.NewClient(ts.URL, "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cli.BeeperSendEphemeralEvent(
|
||||
context.Background(),
|
||||
id.RoomID("!room:example.com"),
|
||||
event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType},
|
||||
map[string]any{"foo": "bar"},
|
||||
)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, mautrix.MUnrecognized))
|
||||
}
|
||||
|
||||
func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) {
|
||||
roomID := id.RoomID("!room:example.com")
|
||||
evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}
|
||||
txnID := "txn-encrypted"
|
||||
|
||||
stateStore := mautrix.NewMemoryStateStore()
|
||||
err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeCrypto := &fakeCryptoHelper{
|
||||
encryptedContent: &event.EncryptedEventContent{
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
MegolmCiphertext: []byte("ciphertext"),
|
||||
},
|
||||
}
|
||||
|
||||
var gotPath string
|
||||
var gotBody map[string]any
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
assert.Equal(t, http.MethodPut, r.Method)
|
||||
err := json.NewDecoder(r.Body).Decode(&gotBody)
|
||||
require.NoError(t, err)
|
||||
_, _ = w.Write([]byte(`{"event_id":"$evt"}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
cli, err := mautrix.NewClient(ts.URL, "", "")
|
||||
require.NoError(t, err)
|
||||
cli.StateStore = stateStore
|
||||
cli.Crypto = fakeCrypto
|
||||
|
||||
_, err = cli.BeeperSendEphemeralEvent(
|
||||
context.Background(),
|
||||
roomID,
|
||||
evtType,
|
||||
map[string]any{"foo": "bar"},
|
||||
mautrix.ReqSendEvent{TransactionID: txnID},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID))
|
||||
assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"])
|
||||
assert.Equal(t, 1, fakeCrypto.encryptCalls)
|
||||
assert.Equal(t, roomID, fakeCrypto.lastRoomID)
|
||||
assert.Equal(t, evtType, fakeCrypto.lastEventType)
|
||||
}
|
||||
|
||||
type fakeCryptoHelper struct {
|
||||
encryptCalls int
|
||||
lastRoomID id.RoomID
|
||||
lastEventType event.Type
|
||||
lastEncryptInput any
|
||||
encryptedContent *event.EncryptedEventContent
|
||||
}
|
||||
|
||||
func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) {
|
||||
f.encryptCalls++
|
||||
f.lastRoomID = roomID
|
||||
f.lastEventType = eventType
|
||||
f.lastEncryptInput = content
|
||||
return f.encryptedContent, nil
|
||||
}
|
||||
|
||||
func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) {
|
||||
}
|
||||
|
||||
func (f *fakeCryptoHelper) Init(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2026 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -8,20 +8,14 @@ package commands
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/util/exmaps"
|
||||
|
||||
"maunium.net/go/mautrix/event/cmdschema"
|
||||
)
|
||||
|
||||
type CommandContainer[MetaType any] struct {
|
||||
commands map[string]*Handler[MetaType]
|
||||
aliases map[string]string
|
||||
lock sync.RWMutex
|
||||
parent *Handler[MetaType]
|
||||
}
|
||||
|
||||
func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
|
||||
|
|
@ -31,29 +25,6 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
|
|||
}
|
||||
}
|
||||
|
||||
func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent {
|
||||
data := make(exmaps.Set[*Handler[MetaType]])
|
||||
cont.collectHandlers(data)
|
||||
specs := make([]*cmdschema.EventContent, 0, data.Size())
|
||||
for handler := range data.Iter() {
|
||||
if handler.Parameters != nil {
|
||||
specs = append(specs, handler.Spec())
|
||||
}
|
||||
}
|
||||
return specs
|
||||
}
|
||||
|
||||
func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) {
|
||||
cont.lock.RLock()
|
||||
defer cont.lock.RUnlock()
|
||||
for _, handler := range cont.commands {
|
||||
into.Add(handler)
|
||||
if handler.subcommandContainer != nil {
|
||||
handler.subcommandContainer.collectHandlers(into)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers the given command handlers.
|
||||
func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) {
|
||||
if cont == nil {
|
||||
|
|
@ -61,10 +32,7 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType])
|
|||
}
|
||||
cont.lock.Lock()
|
||||
defer cont.lock.Unlock()
|
||||
for i, handler := range handlers {
|
||||
if handler == nil {
|
||||
panic(fmt.Errorf("handler #%d is nil", i+1))
|
||||
}
|
||||
for _, handler := range handlers {
|
||||
cont.registerOne(handler)
|
||||
}
|
||||
}
|
||||
|
|
@ -77,10 +45,6 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType])
|
|||
} else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists {
|
||||
panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget))
|
||||
}
|
||||
if !slices.Contains(handler.parents, cont.parent) {
|
||||
handler.parents = append(handler.parents, cont.parent)
|
||||
handler.nestedNameCache = nil
|
||||
}
|
||||
cont.commands[handler.Name] = handler
|
||||
for _, alias := range handler.Aliases {
|
||||
if strings.ToLower(alias) != alias {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2026 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -8,7 +8,6 @@ package commands
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
|
|
@ -36,8 +35,6 @@ type Event[MetaType any] struct {
|
|||
// RawArgs is the same as args, but without the splitting by whitespace.
|
||||
RawArgs string
|
||||
|
||||
StructuredArgs json.RawMessage
|
||||
|
||||
Ctx context.Context
|
||||
Log *zerolog.Logger
|
||||
Proc *Processor[MetaType]
|
||||
|
|
@ -64,7 +61,7 @@ var IDHTMLParser = &format.HTMLParser{
|
|||
}
|
||||
|
||||
// ParseEvent parses a message into a command event struct.
|
||||
func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] {
|
||||
func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] {
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" {
|
||||
return nil
|
||||
|
|
@ -73,34 +70,12 @@ func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Even
|
|||
if content.Format == event.FormatHTML {
|
||||
text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx))
|
||||
}
|
||||
if content.MSC4391BotCommand != nil {
|
||||
if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 {
|
||||
return nil
|
||||
}
|
||||
wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand)
|
||||
wrapped.RawInput = text
|
||||
return wrapped
|
||||
}
|
||||
if len(text) == 0 {
|
||||
return nil
|
||||
}
|
||||
return RawTextToEvent[MetaType](ctx, evt, text)
|
||||
}
|
||||
|
||||
func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] {
|
||||
commandParts := strings.Split(content.Command, " ")
|
||||
return &Event[MetaType]{
|
||||
Event: evt,
|
||||
// Fake a command and args to let the subcommand finder in Process work.
|
||||
Command: commandParts[0],
|
||||
Args: commandParts[1:],
|
||||
Ctx: ctx,
|
||||
Log: zerolog.Ctx(ctx),
|
||||
|
||||
StructuredArgs: content.Arguments,
|
||||
}
|
||||
}
|
||||
|
||||
func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] {
|
||||
parts := strings.Fields(text)
|
||||
if len(parts) == 0 {
|
||||
|
|
@ -213,25 +188,3 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) {
|
|||
evt.RawArgs = arg + " " + evt.RawArgs
|
||||
evt.Args = append([]string{arg}, evt.Args...)
|
||||
}
|
||||
|
||||
func (evt *Event[MetaType]) ParseArgs(into any) error {
|
||||
return json.Unmarshal(evt.StructuredArgs, into)
|
||||
}
|
||||
|
||||
func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) {
|
||||
err = evt.ParseArgs(&into)
|
||||
return
|
||||
}
|
||||
|
||||
func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) {
|
||||
return func(evt *Event[MetaType]) {
|
||||
parsed, err := ParseArgs[T, MetaType](evt)
|
||||
if err != nil {
|
||||
evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct")
|
||||
// TODO better error, usage info? deduplicate with Process
|
||||
evt.Reply("Failed to parse arguments: %v", err)
|
||||
return
|
||||
}
|
||||
fn(evt, parsed)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2026 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -8,9 +8,6 @@ package commands
|
|||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/event/cmdschema"
|
||||
)
|
||||
|
||||
type Handler[MetaType any] struct {
|
||||
|
|
@ -28,63 +25,12 @@ type Handler[MetaType any] struct {
|
|||
// Event.ShiftArg will likely be useful for implementing such parameters.
|
||||
PreFunc func(ce *Event[MetaType])
|
||||
|
||||
// Description is a short description of the command.
|
||||
Description *event.ExtensibleTextContainer
|
||||
// Parameters is a description of structured command parameters.
|
||||
// If set, the StructuredArgs field of Event will be populated.
|
||||
Parameters []*cmdschema.Parameter
|
||||
TailParam string
|
||||
|
||||
parents []*Handler[MetaType]
|
||||
nestedNameCache []string
|
||||
subcommandContainer *CommandContainer[MetaType]
|
||||
}
|
||||
|
||||
func (h *Handler[MetaType]) NestedNames() []string {
|
||||
if h.nestedNameCache != nil {
|
||||
return h.nestedNameCache
|
||||
}
|
||||
nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents))
|
||||
for _, parent := range h.parents {
|
||||
if parent == nil {
|
||||
nestedNames = append(nestedNames, h.Name)
|
||||
nestedNames = append(nestedNames, h.Aliases...)
|
||||
} else {
|
||||
for _, parentName := range parent.NestedNames() {
|
||||
nestedNames = append(nestedNames, parentName+" "+h.Name)
|
||||
for _, alias := range h.Aliases {
|
||||
nestedNames = append(nestedNames, parentName+" "+alias)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
h.nestedNameCache = nestedNames
|
||||
return nestedNames
|
||||
}
|
||||
|
||||
func (h *Handler[MetaType]) Spec() *cmdschema.EventContent {
|
||||
names := h.NestedNames()
|
||||
return &cmdschema.EventContent{
|
||||
Command: names[0],
|
||||
Aliases: names[1:],
|
||||
Parameters: h.Parameters,
|
||||
Description: h.Description,
|
||||
TailParam: h.TailParam,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) {
|
||||
if h.Parameters == nil {
|
||||
h.Parameters = other.Parameters
|
||||
h.TailParam = other.TailParam
|
||||
}
|
||||
h.Func = other.Func
|
||||
}
|
||||
|
||||
func (h *Handler[MetaType]) initSubcommandContainer() {
|
||||
if len(h.Subcommands) > 0 {
|
||||
h.subcommandContainer = NewCommandContainer[MetaType]()
|
||||
h.subcommandContainer.parent = h
|
||||
h.subcommandContainer.Register(h.Subcommands...)
|
||||
} else {
|
||||
h.subcommandContainer = nil
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2026 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
|
|||
case event.EventReaction:
|
||||
parsed = proc.ParseReaction(ctx, evt)
|
||||
case event.EventMessage:
|
||||
parsed = proc.ParseEvent(ctx, evt)
|
||||
parsed = ParseEvent[MetaType](ctx, evt)
|
||||
}
|
||||
if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) {
|
||||
if parsed == nil || !proc.PreValidator.Validate(parsed) {
|
||||
return
|
||||
}
|
||||
parsed.Proc = proc
|
||||
|
|
@ -107,12 +107,6 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
|
|||
break
|
||||
}
|
||||
}
|
||||
if parsed.StructuredArgs != nil && len(parsed.Args) > 0 {
|
||||
// TODO allow unknown command handlers to be called?
|
||||
// The client sent MSC4391 data, but the target command wasn't found
|
||||
log.Debug().Msg("Didn't find handler for MSC4391 command")
|
||||
return
|
||||
}
|
||||
|
||||
logWith := log.With().
|
||||
Str("command", parsed.Command).
|
||||
|
|
@ -122,31 +116,11 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
|
|||
}
|
||||
if proc.LogArgs {
|
||||
logWith = logWith.Strs("args", parsed.Args)
|
||||
if parsed.StructuredArgs != nil {
|
||||
logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs)
|
||||
}
|
||||
}
|
||||
log = logWith.Logger()
|
||||
parsed.Ctx = log.WithContext(ctx)
|
||||
parsed.Log = &log
|
||||
|
||||
if handler.Parameters != nil && parsed.StructuredArgs == nil {
|
||||
// The handler wants structured parameters, but the client didn't send MSC4391 data
|
||||
var err error
|
||||
parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("Failed to parse structured arguments")
|
||||
// TODO better error, usage info? deduplicate with WithParsedArgs
|
||||
parsed.Reply("Failed to parse arguments: %v", err)
|
||||
return
|
||||
}
|
||||
if proc.LogArgs {
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.RawJSON("structured_args", parsed.StructuredArgs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug().Msg("Processing command")
|
||||
handler.Func(parsed)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2026 Tulir Asokan
|
||||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
@ -8,7 +8,6 @@ package commands
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
|
@ -20,11 +19,6 @@ import (
|
|||
const ReactionCommandsKey = "fi.mau.reaction_commands"
|
||||
const ReactionMultiUseKey = "fi.mau.reaction_multi_use"
|
||||
|
||||
type ReactionCommandData struct {
|
||||
Command string `json:"command"`
|
||||
Args any `json:"args,omitempty"`
|
||||
}
|
||||
|
||||
func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] {
|
||||
content, ok := evt.Content.Parsed.(*event.ReactionEventContent)
|
||||
if !ok {
|
||||
|
|
@ -73,33 +67,21 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E
|
|||
Msg("Reaction command not found in target event")
|
||||
return nil
|
||||
}
|
||||
var wrappedEvt *Event[MetaType]
|
||||
switch typedCmd := rawCmd.(type) {
|
||||
case string:
|
||||
wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd)
|
||||
case map[string]any:
|
||||
var input event.MSC4391BotCommandInput
|
||||
if marshaled, err := json.Marshal(typedCmd); err != nil {
|
||||
|
||||
} else if err = json.Unmarshal(marshaled, &input); err != nil {
|
||||
|
||||
} else {
|
||||
wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input)
|
||||
}
|
||||
}
|
||||
if wrappedEvt == nil {
|
||||
cmdString, ok := rawCmd.(string)
|
||||
if !ok {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Stringer("target_event_id", evtID).
|
||||
Str("reaction_key", content.RelatesTo.Key).
|
||||
Msg("Reaction command data is invalid")
|
||||
return nil
|
||||
}
|
||||
wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString)
|
||||
wrappedEvt.Proc = proc
|
||||
wrappedEvt.Redact()
|
||||
if !isMultiUse {
|
||||
DeleteAllReactions(ctx, proc.Client, evt)
|
||||
}
|
||||
if wrappedEvt.Command == "" {
|
||||
if cmdString == "" {
|
||||
return nil
|
||||
}
|
||||
return wrappedEvt
|
||||
|
|
|
|||
|
|
@ -21,24 +21,13 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrHashMismatch = errors.New("mismatching SHA-256 digest")
|
||||
ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version")
|
||||
ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
|
||||
ErrInvalidKey = errors.New("failed to decode key")
|
||||
ErrInvalidInitVector = errors.New("failed to decode initialization vector")
|
||||
ErrInvalidHash = errors.New("failed to decode SHA-256 hash")
|
||||
ErrReaderClosed = errors.New("encrypting reader was already closed")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
HashMismatch = ErrHashMismatch
|
||||
UnsupportedVersion = ErrUnsupportedVersion
|
||||
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
|
||||
InvalidKey = ErrInvalidKey
|
||||
InvalidInitVector = ErrInvalidInitVector
|
||||
InvalidHash = ErrInvalidHash
|
||||
ReaderClosed = ErrReaderClosed
|
||||
HashMismatch = errors.New("mismatching SHA-256 digest")
|
||||
UnsupportedVersion = errors.New("unsupported Matrix file encryption version")
|
||||
UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm")
|
||||
InvalidKey = errors.New("failed to decode key")
|
||||
InvalidInitVector = errors.New("failed to decode initialization vector")
|
||||
InvalidHash = errors.New("failed to decode SHA-256 hash")
|
||||
ReaderClosed = errors.New("encrypting reader was already closed")
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -96,25 +85,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error {
|
|||
if ef.decoded != nil {
|
||||
return nil
|
||||
} else if len(ef.Key.Key) != keyBase64Length {
|
||||
return ErrInvalidKey
|
||||
return InvalidKey
|
||||
} else if len(ef.InitVector) != ivBase64Length {
|
||||
return ErrInvalidInitVector
|
||||
return InvalidInitVector
|
||||
} else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length {
|
||||
return ErrInvalidHash
|
||||
return InvalidHash
|
||||
}
|
||||
ef.decoded = &decodedKeys{}
|
||||
_, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key))
|
||||
if err != nil {
|
||||
return ErrInvalidKey
|
||||
return InvalidKey
|
||||
}
|
||||
_, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector))
|
||||
if err != nil {
|
||||
return ErrInvalidInitVector
|
||||
return InvalidInitVector
|
||||
}
|
||||
if includeHash {
|
||||
_, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256))
|
||||
if err != nil {
|
||||
return ErrInvalidHash
|
||||
return InvalidHash
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
@ -190,7 +179,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil)
|
|||
|
||||
func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
|
||||
if r.closed {
|
||||
return 0, ErrReaderClosed
|
||||
return 0, ReaderClosed
|
||||
}
|
||||
if offset != 0 || whence != io.SeekStart {
|
||||
return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported")
|
||||
|
|
@ -211,7 +200,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) {
|
|||
|
||||
func (r *encryptingReader) Read(dst []byte) (n int, err error) {
|
||||
if r.closed {
|
||||
return 0, ErrReaderClosed
|
||||
return 0, ReaderClosed
|
||||
} else if r.isDecrypting && r.file.decoded == nil {
|
||||
if err = r.file.PrepareForDecryption(); err != nil {
|
||||
return
|
||||
|
|
@ -235,7 +224,7 @@ func (r *encryptingReader) Close() (err error) {
|
|||
}
|
||||
if r.isDecrypting {
|
||||
if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) {
|
||||
return ErrHashMismatch
|
||||
return HashMismatch
|
||||
}
|
||||
} else {
|
||||
r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil))
|
||||
|
|
@ -276,9 +265,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) {
|
|||
// DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function.
|
||||
func (ef *EncryptedFile) PrepareForDecryption() error {
|
||||
if ef.Version != "v2" {
|
||||
return ErrUnsupportedVersion
|
||||
return UnsupportedVersion
|
||||
} else if ef.Key.Algorithm != "A256CTR" {
|
||||
return ErrUnsupportedAlgorithm
|
||||
return UnsupportedAlgorithm
|
||||
} else if err := ef.decodeKeys(true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -292,7 +281,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error {
|
|||
}
|
||||
dataHash := sha256.Sum256(data)
|
||||
if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) {
|
||||
return ErrHashMismatch
|
||||
return HashMismatch
|
||||
}
|
||||
utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv)
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) {
|
|||
file := parseHelloWorld()
|
||||
file.Version = "foo"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, ErrUnsupportedVersion)
|
||||
assert.ErrorIs(t, err, UnsupportedVersion)
|
||||
}
|
||||
|
||||
func TestUnsupportedAlgorithm(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Key.Algorithm = "bar"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
|
||||
assert.ErrorIs(t, err, UnsupportedAlgorithm)
|
||||
}
|
||||
|
||||
func TestHashMismatch(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes))
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, ErrHashMismatch)
|
||||
assert.ErrorIs(t, err, HashMismatch)
|
||||
}
|
||||
|
||||
func TestTooLongHash(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, ErrInvalidHash)
|
||||
assert.ErrorIs(t, err, InvalidHash)
|
||||
}
|
||||
|
||||
func TestTooShortHash(t *testing.T) {
|
||||
file := parseHelloWorld()
|
||||
file.Hashes.SHA256 = "5/Gy1JftyyQ"
|
||||
err := file.DecryptInPlace([]byte(helloWorldCiphertext))
|
||||
assert.ErrorIs(t, err, ErrInvalidHash)
|
||||
assert.ErrorIs(t, err, InvalidHash)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross
|
|||
}
|
||||
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
|
||||
|
||||
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{
|
||||
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
|
||||
Master: masterKey,
|
||||
SelfSigning: selfKey,
|
||||
UserSigning: userKey,
|
||||
|
|
|
|||
|
|
@ -20,20 +20,6 @@ type CrossSigningPublicKeysCache struct {
|
|||
UserSigningKey id.Ed25519
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) {
|
||||
pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx)
|
||||
if pubkeys != nil {
|
||||
hasKeys = true
|
||||
isVerified, err = mach.CryptoStore.IsKeySignedBy(
|
||||
ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey,
|
||||
)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache {
|
||||
if mach.crossSigningPubkeys != nil {
|
||||
return mach.crossSigningPubkeys
|
||||
|
|
@ -63,8 +49,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id
|
|||
if len(dbKeys) > 0 {
|
||||
masterKey, ok := dbKeys[id.XSUsageMaster]
|
||||
if ok {
|
||||
selfSigning := dbKeys[id.XSUsageSelfSigning]
|
||||
userSigning := dbKeys[id.XSUsageUserSigning]
|
||||
selfSigning, _ := dbKeys[id.XSUsageSelfSigning]
|
||||
userSigning, _ := dbKeys[id.XSUsageUserSigning]
|
||||
return &CrossSigningPublicKeysCache{
|
||||
MasterKey: masterKey.Key,
|
||||
SelfSigningKey: selfSigning.Key,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ package crypto
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
|
|
@ -72,46 +71,6 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex
|
|||
}, passphrase)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error {
|
||||
keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get default SSSS key data: %w", err)
|
||||
}
|
||||
key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey)
|
||||
if errors.Is(err, ssss.ErrUnverifiableKey) {
|
||||
mach.machOrContextLog(ctx).Warn().
|
||||
Str("key_id", keyID).
|
||||
Msg("SSSS key is unverifiable, trying to use without verifying")
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
err = mach.FetchCrossSigningKeysFromSSSS(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
|
||||
}
|
||||
err = mach.SignOwnDevice(ctx, mach.OwnIdentity())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign own device: %w", err)
|
||||
}
|
||||
err = mach.SignOwnMasterKey(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign own master key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) {
|
||||
recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "")
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err)
|
||||
} else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil {
|
||||
err = fmt.Errorf("failed to sign own device: %w", err)
|
||||
} else if err = mach.SignOwnMasterKey(ctx); err != nil {
|
||||
err = fmt.Errorf("failed to sign own master key: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys.
|
||||
//
|
||||
// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key
|
||||
|
|
@ -138,12 +97,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u
|
|||
// Publish cross-signing keys
|
||||
err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback)
|
||||
if err != nil {
|
||||
return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err)
|
||||
return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err)
|
||||
}
|
||||
|
||||
err = mach.SSSS.SetDefaultKeyID(ctx, key.ID)
|
||||
if err != nil {
|
||||
return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
|
||||
return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
|
||||
}
|
||||
|
||||
return key.RecoveryKey(), keysCache, nil
|
||||
|
|
|
|||
|
|
@ -20,34 +20,36 @@ import (
|
|||
func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
|
||||
log := mach.machOrContextLog(ctx)
|
||||
for userID, userKeys := range crossSigningKeys {
|
||||
log := log.With().Stringer("user_id", userID).Logger()
|
||||
log := log.With().Str("user_id", userID.String()).Logger()
|
||||
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Msg("Error fetching current cross-signing keys of user")
|
||||
}
|
||||
for curKeyUsage, curKey := range currentKeys {
|
||||
log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger()
|
||||
// got a new key with the same usage as an existing key
|
||||
for _, newKeyUsage := range userKeys.Usage {
|
||||
if newKeyUsage == curKeyUsage {
|
||||
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
|
||||
// old key is not in the new key map, so we drop signatures made by it
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
|
||||
log.Error().Err(err).Msg("Error deleting old signatures made by user")
|
||||
} else {
|
||||
log.Debug().
|
||||
Int64("signature_count", count).
|
||||
Msg("Dropped signatures made by old key as it has been replaced")
|
||||
if currentKeys != nil {
|
||||
for curKeyUsage, curKey := range currentKeys {
|
||||
log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger()
|
||||
// got a new key with the same usage as an existing key
|
||||
for _, newKeyUsage := range userKeys.Usage {
|
||||
if newKeyUsage == curKeyUsage {
|
||||
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
|
||||
// old key is not in the new key map, so we drop signatures made by it
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
|
||||
log.Error().Err(err).Msg("Error deleting old signatures made by user")
|
||||
} else {
|
||||
log.Debug().
|
||||
Int64("signature_count", count).
|
||||
Msg("Dropped signatures made by old key as it has been replaced")
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range userKeys.Keys {
|
||||
log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
|
||||
log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger()
|
||||
for _, usage := range userKeys.Usage {
|
||||
log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key")
|
||||
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {
|
||||
|
|
|
|||
|
|
@ -225,6 +225,13 @@ func (helper *CryptoHelper) Init(ctx context.Context) error {
|
|||
helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted)
|
||||
}
|
||||
|
||||
if helper.client.SetAppServiceDeviceID {
|
||||
err = helper.mach.ShareKeys(ctx, -1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to share keys: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -261,24 +268,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error
|
|||
if !ok || len(device.Keys) == 0 {
|
||||
if isShared {
|
||||
return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server")
|
||||
} else {
|
||||
helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine")
|
||||
return nil
|
||||
}
|
||||
helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys")
|
||||
err = helper.mach.ShareKeys(ctx, -1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to share keys: %w", err)
|
||||
}
|
||||
return nil
|
||||
} else if !isShared {
|
||||
return fmt.Errorf("olm account is not marked as shared, but there are keys on the server")
|
||||
} else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed {
|
||||
return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed)
|
||||
}
|
||||
if !isShared {
|
||||
helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?")
|
||||
} else {
|
||||
helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine")
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var NoSessionFound = crypto.ErrNoSessionFound
|
||||
var NoSessionFound = crypto.NoSessionFound
|
||||
|
||||
const initialSessionWaitTimeout = 3 * time.Second
|
||||
const extendedSessionWaitTimeout = 22 * time.Second
|
||||
|
|
@ -297,14 +304,24 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even
|
|||
ctx = log.WithContext(ctx)
|
||||
|
||||
decrypted, err := helper.Decrypt(ctx, evt)
|
||||
if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" {
|
||||
go helper.waitForSession(ctx, evt)
|
||||
} else if err != nil {
|
||||
if errors.Is(err, NoSessionFound) {
|
||||
log.Debug().
|
||||
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, waiting for keys to arrive...")
|
||||
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
|
||||
decrypted, err = helper.Decrypt(ctx, evt)
|
||||
} else {
|
||||
go helper.waitLongerForSession(ctx, log, evt)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to decrypt event")
|
||||
helper.DecryptErrorCallback(evt, err)
|
||||
} else {
|
||||
helper.postDecrypt(ctx, decrypted)
|
||||
return
|
||||
}
|
||||
helper.postDecrypt(ctx, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) {
|
||||
|
|
@ -345,33 +362,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
|
|||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
content := evt.Content.AsEncrypted()
|
||||
|
||||
log.Debug().
|
||||
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, waiting for keys to arrive...")
|
||||
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
|
||||
decrypted, err := helper.Decrypt(ctx, evt)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to decrypt event")
|
||||
helper.DecryptErrorCallback(evt, err)
|
||||
} else {
|
||||
helper.postDecrypt(ctx, decrypted)
|
||||
}
|
||||
} else {
|
||||
go helper.waitLongerForSession(ctx, evt)
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) {
|
||||
content := evt.Content.AsEncrypted()
|
||||
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
|
||||
|
||||
//lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank
|
||||
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
|
||||
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
|
|
@ -419,7 +413,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R
|
|||
defer helper.lock.RUnlock()
|
||||
encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content)
|
||||
if err != nil {
|
||||
if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) {
|
||||
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().
|
||||
|
|
|
|||
|
|
@ -24,23 +24,13 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
|
||||
ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
|
||||
ErrDuplicateMessageIndex = errors.New("duplicate megolm message index")
|
||||
ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room")
|
||||
ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
|
||||
ErrRatchetError = errors.New("failed to ratchet session after use")
|
||||
ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType
|
||||
NoSessionFound = ErrNoSessionFound
|
||||
DuplicateMessageIndex = ErrDuplicateMessageIndex
|
||||
WrongRoom = ErrWrongRoom
|
||||
DeviceKeyMismatch = ErrDeviceKeyMismatch
|
||||
RatchetError = ErrRatchetError
|
||||
IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent")
|
||||
NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found")
|
||||
DuplicateMessageIndex = errors.New("duplicate megolm message index")
|
||||
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
|
||||
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
|
||||
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
|
||||
RatchetError = errors.New("failed to ratchet session after use")
|
||||
)
|
||||
|
||||
type megolmEvent struct {
|
||||
|
|
@ -55,30 +45,13 @@ var (
|
|||
relatesToTopLevelPath = exgjson.Path("content", "m.relates_to")
|
||||
)
|
||||
|
||||
const sessionIDLength = 43
|
||||
|
||||
func validateCiphertextCharacters(ciphertext []byte) bool {
|
||||
for _, b := range ciphertext {
|
||||
if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
|
||||
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
|
||||
if !ok {
|
||||
return nil, ErrIncorrectEncryptedContentType
|
||||
return nil, IncorrectEncryptedContentType
|
||||
} else if content.Algorithm != id.AlgorithmMegolmV1 {
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
} else if len(content.MegolmCiphertext) < 74 {
|
||||
return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext))
|
||||
} else if len(content.SessionID) != sessionIDLength {
|
||||
return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID))
|
||||
} else if !validateCiphertextCharacters(content.MegolmCiphertext) {
|
||||
return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload)
|
||||
return nil, UnsupportedAlgorithm
|
||||
}
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("action", "decrypt megolm event").
|
||||
|
|
@ -124,13 +97,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
Msg("Couldn't resolve trust level of session: sent by unknown device")
|
||||
trustLevel = id.TrustStateUnknownDevice
|
||||
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
|
||||
log.Debug().
|
||||
Stringer("session_sender_key", sess.SenderKey).
|
||||
Stringer("device_sender_key", device.IdentityKey).
|
||||
Stringer("session_signing_key", sess.SigningKey).
|
||||
Stringer("device_signing_key", device.SigningKey).
|
||||
Msg("Device keys don't match keys in session, marking as untrusted")
|
||||
trustLevel = id.TrustStateDeviceKeyMismatch
|
||||
return nil, DeviceKeyMismatch
|
||||
} else {
|
||||
trustLevel, err = mach.ResolveTrustContext(ctx, device)
|
||||
if err != nil {
|
||||
|
|
@ -180,7 +147,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse megolm payload: %w", err)
|
||||
} else if megolmEvt.RoomID != encryptionRoomID {
|
||||
return nil, ErrWrongRoom
|
||||
return nil, WrongRoom
|
||||
}
|
||||
if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState {
|
||||
megolmEvt.Type.Class = event.StateEventType
|
||||
|
|
@ -213,7 +180,6 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
TrustSource: device,
|
||||
ForwardedKeys: forwardedKeys,
|
||||
WasEncrypted: true,
|
||||
EventSource: evt.Mautrix.EventSource | event.SourceDecrypted,
|
||||
ReceivedAt: evt.Mautrix.ReceivedAt,
|
||||
},
|
||||
}, nil
|
||||
|
|
@ -235,19 +201,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
|
|||
messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext)
|
||||
if decodeErr != nil {
|
||||
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
|
||||
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex)
|
||||
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
|
||||
}
|
||||
firstKnown := sess.Internal.FirstKnownIndex()
|
||||
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
|
||||
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
|
||||
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
|
||||
return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
|
||||
} else if !ok {
|
||||
log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate")
|
||||
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown)
|
||||
return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown)
|
||||
}
|
||||
log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate")
|
||||
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown)
|
||||
return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
|
||||
|
|
@ -258,11 +224,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID)
|
||||
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
|
||||
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
|
||||
return sess, nil, 0, SenderKeyMismatch
|
||||
}
|
||||
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
|
||||
if err != nil {
|
||||
if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
|
||||
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
|
||||
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
|
||||
return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
||||
}
|
||||
|
|
@ -270,7 +238,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
|
||||
} else if !ok {
|
||||
return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex)
|
||||
return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex)
|
||||
}
|
||||
|
||||
// Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function
|
||||
|
|
@ -322,24 +290,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete fully used session")
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Info().Msg("Deleted fully used session")
|
||||
}
|
||||
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
|
||||
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
||||
log.Err(err).Msg("Failed to ratchet session")
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store ratcheted session")
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Info().Msg("Ratcheted session forward")
|
||||
}
|
||||
} else if didModify {
|
||||
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
||||
return sess, plaintext, messageIndex, ErrRatchetError
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,36 +17,21 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/ptr"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/account"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
|
||||
ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
|
||||
ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type")
|
||||
ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
|
||||
ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
|
||||
ErrSenderMismatch = errors.New("mismatched sender in olm payload")
|
||||
ErrRecipientMismatch = errors.New("mismatched recipient in olm payload")
|
||||
ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
|
||||
ErrDuplicateMessage = errors.New("duplicate olm message")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
UnsupportedAlgorithm = ErrUnsupportedAlgorithm
|
||||
NotEncryptedForMe = ErrNotEncryptedForMe
|
||||
UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType
|
||||
DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession
|
||||
DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage
|
||||
SenderMismatch = ErrSenderMismatch
|
||||
RecipientMismatch = ErrRecipientMismatch
|
||||
RecipientKeyMismatch = ErrRecipientKeyMismatch
|
||||
UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm")
|
||||
NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device")
|
||||
UnsupportedOlmMessageType = errors.New("unsupported olm message type")
|
||||
DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session")
|
||||
DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message")
|
||||
SenderMismatch = errors.New("mismatched sender in olm payload")
|
||||
RecipientMismatch = errors.New("mismatched recipient in olm payload")
|
||||
RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
|
||||
ErrDuplicateMessage = errors.New("duplicate olm message")
|
||||
)
|
||||
|
||||
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
|
||||
|
|
@ -68,13 +53,13 @@ type DecryptedOlmEvent struct {
|
|||
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
|
||||
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
|
||||
if !ok {
|
||||
return nil, ErrIncorrectEncryptedContentType
|
||||
return nil, IncorrectEncryptedContentType
|
||||
} else if content.Algorithm != id.AlgorithmOlmV1 {
|
||||
return nil, ErrUnsupportedAlgorithm
|
||||
return nil, UnsupportedAlgorithm
|
||||
}
|
||||
ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()]
|
||||
if !ok {
|
||||
return nil, ErrNotEncryptedForMe
|
||||
return nil, NotEncryptedForMe
|
||||
}
|
||||
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
|
||||
if err != nil {
|
||||
|
|
@ -90,7 +75,7 @@ type OlmEventKeys struct {
|
|||
|
||||
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
|
||||
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
|
||||
return nil, ErrUnsupportedOlmMessageType
|
||||
return nil, UnsupportedOlmMessageType
|
||||
}
|
||||
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
|
|
@ -114,11 +99,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
|
|||
}
|
||||
olmEvt.Type.Class = evt.Type.Class
|
||||
if evt.Sender != olmEvt.Sender {
|
||||
return nil, ErrSenderMismatch
|
||||
return nil, SenderMismatch
|
||||
} else if mach.Client.UserID != olmEvt.Recipient {
|
||||
return nil, ErrRecipientMismatch
|
||||
return nil, RecipientMismatch
|
||||
} else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 {
|
||||
return nil, ErrRecipientKeyMismatch
|
||||
return nil, RecipientKeyMismatch
|
||||
}
|
||||
|
||||
if len(olmEvt.Content.VeryRaw) > 0 {
|
||||
|
|
@ -134,9 +119,6 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
|
|||
}
|
||||
|
||||
func olmMessageHash(ciphertext string) ([32]byte, error) {
|
||||
if ciphertext == "" {
|
||||
return [32]byte{}, fmt.Errorf("empty ciphertext")
|
||||
}
|
||||
ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext)
|
||||
return sha256.Sum256(ciphertextBytes), err
|
||||
}
|
||||
|
|
@ -166,7 +148,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
|
||||
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
|
||||
if err != nil {
|
||||
if err == ErrDecryptionFailedWithMatchingSession {
|
||||
if err == DecryptionFailedWithMatchingSession {
|
||||
log.Warn().Msg("Found matching session, but decryption failed")
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
}
|
||||
|
|
@ -184,10 +166,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
// if it isn't one at this point in time anymore, so return early.
|
||||
if olmType != id.OlmMsgTypePreKey {
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
return nil, ErrDecryptionFailedForNormalMessage
|
||||
return nil, DecryptionFailedForNormalMessage
|
||||
}
|
||||
|
||||
accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp"))
|
||||
log.Trace().Msg("Trying to create inbound session")
|
||||
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
|
||||
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
|
||||
|
|
@ -199,7 +180,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
|
||||
log.Debug().
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Created inbound olm session")
|
||||
ctx = log.WithContext(ctx)
|
||||
|
|
@ -208,19 +188,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
plaintext, err = session.Decrypt(ciphertext, olmType)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Debug().
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
|
||||
Str("ciphertext", ciphertext).
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("DEBUG: Failed to decrypt prekey olm message with newly created session")
|
||||
err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup)
|
||||
if err2 != nil {
|
||||
log.Debug().Err(err2).Msg("Goolm confirmed decryption failure")
|
||||
} else {
|
||||
log.Warn().Msg("Goolm decryption was successful after libolm failure?")
|
||||
}
|
||||
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err)
|
||||
}
|
||||
|
|
@ -238,23 +205,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
return plaintext, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error {
|
||||
acc, err := account.AccountFromPickled(accountBackup, []byte("tmp"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unpickle olm account: %w", err)
|
||||
}
|
||||
sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create inbound session: %w", err)
|
||||
}
|
||||
_, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey)
|
||||
if err != nil {
|
||||
// This is the expected result if libolm failed
|
||||
return fmt.Errorf("failed to decrypt with new session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const MaxOlmSessionsPerDevice = 5
|
||||
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
|
||||
|
|
@ -313,11 +263,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
|
|||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]).
|
||||
Str("session_description", session.Describe()).
|
||||
Msg("Failed to decrypt olm message")
|
||||
if olmType == id.OlmMsgTypePreKey {
|
||||
return nil, ErrDecryptionFailedWithMatchingSession
|
||||
return nil, DecryptionFailedWithMatchingSession
|
||||
}
|
||||
} else {
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
|
||||
|
|
@ -357,10 +306,10 @@ const MinUnwedgeInterval = 1 * time.Hour
|
|||
|
||||
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
|
||||
log = log.With().Str("action", "unwedge olm session").Logger()
|
||||
ctx := log.WithContext(mach.backgroundCtx)
|
||||
ctx := log.WithContext(mach.BackgroundCtx)
|
||||
mach.recentlyUnwedgedLock.Lock()
|
||||
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
|
||||
delta := time.Since(prevUnwedge)
|
||||
delta := time.Now().Sub(prevUnwedge)
|
||||
if ok && delta < MinUnwedgeInterval {
|
||||
log.Debug().
|
||||
Str("previous_recreation", delta.String()).
|
||||
|
|
@ -391,10 +340,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send
|
|||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Time("last_created", lastCreatedAt).
|
||||
Stringer("device_id", deviceIdentity.DeviceID).
|
||||
Msg("Creating new Olm session")
|
||||
log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session")
|
||||
mach.devicesToUnwedgeLock.Lock()
|
||||
mach.devicesToUnwedge[senderKey] = true
|
||||
mach.devicesToUnwedgeLock.Unlock()
|
||||
|
|
|
|||
|
|
@ -22,23 +22,14 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
|
||||
ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
|
||||
ErrMismatchingSigningKey = errors.New("received update for device with different signing key")
|
||||
ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key")
|
||||
ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
|
||||
ErrInvalidKeySignature = errors.New("invalid signature on device keys")
|
||||
ErrUserNotTracked = errors.New("user is not tracked")
|
||||
)
|
||||
MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object")
|
||||
MismatchingUserID = errors.New("mismatching user ID in parameter and keys object")
|
||||
MismatchingSigningKey = errors.New("received update for device with different signing key")
|
||||
NoSigningKeyFound = errors.New("didn't find ed25519 signing key")
|
||||
NoIdentityKeyFound = errors.New("didn't find curve25519 identity key")
|
||||
InvalidKeySignature = errors.New("invalid signature on device keys")
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
MismatchingDeviceID = ErrMismatchingDeviceID
|
||||
MismatchingUserID = ErrMismatchingUserID
|
||||
MismatchingSigningKey = ErrMismatchingSigningKey
|
||||
NoSigningKeyFound = ErrNoSigningKeyFound
|
||||
NoIdentityKeyFound = ErrNoIdentityKeyFound
|
||||
InvalidKeySignature = ErrInvalidKeySignature
|
||||
ErrUserNotTracked = errors.New("user is not tracked")
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
|
||||
|
|
@ -215,7 +206,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
|
|||
log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received")
|
||||
data = make(map[id.UserID]map[id.DeviceID]*id.Device)
|
||||
for userID, devices := range resp.DeviceKeys {
|
||||
log := log.With().Stringer("user_id", userID).Logger()
|
||||
log := log.With().Str("user_id", userID.String()).Logger()
|
||||
delete(req.DeviceKeys, userID)
|
||||
|
||||
newDevices := make(map[id.DeviceID]*id.Device)
|
||||
|
|
@ -231,7 +222,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
|
|||
Msg("Updating devices in store")
|
||||
changed := false
|
||||
for deviceID, deviceKeys := range devices {
|
||||
log := log.With().Stringer("device_id", deviceID).Logger()
|
||||
log := log.With().Str("device_id", deviceID.String()).Logger()
|
||||
existing, ok := existingDevices[deviceID]
|
||||
if !ok {
|
||||
// New device
|
||||
|
|
@ -279,7 +270,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ
|
|||
}
|
||||
}
|
||||
for userID := range req.DeviceKeys {
|
||||
log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user")
|
||||
log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user")
|
||||
}
|
||||
|
||||
mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys)
|
||||
|
|
@ -321,28 +312,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID)
|
|||
|
||||
func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) {
|
||||
if deviceID != deviceKeys.DeviceID {
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID)
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID)
|
||||
} else if userID != deviceKeys.UserID {
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID)
|
||||
return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID)
|
||||
}
|
||||
|
||||
signingKey := deviceKeys.Keys.GetEd25519(deviceID)
|
||||
identityKey := deviceKeys.Keys.GetCurve25519(deviceID)
|
||||
if signingKey == "" {
|
||||
return nil, ErrNoSigningKeyFound
|
||||
return nil, NoSigningKeyFound
|
||||
} else if identityKey == "" {
|
||||
return nil, ErrNoIdentityKeyFound
|
||||
return nil, NoIdentityKeyFound
|
||||
}
|
||||
|
||||
if existing != nil && existing.SigningKey != signingKey {
|
||||
return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey)
|
||||
return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
|
||||
}
|
||||
|
||||
ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
|
||||
if err != nil {
|
||||
return existing, fmt.Errorf("failed to verify signature: %w", err)
|
||||
} else if !ok {
|
||||
return existing, ErrInvalidKeySignature
|
||||
return existing, InvalidKeySignature
|
||||
}
|
||||
|
||||
name, ok := deviceKeys.Unsigned["device_display_name"].(string)
|
||||
|
|
|
|||
|
|
@ -25,12 +25,7 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrNoGroupSession = errors.New("no group session created")
|
||||
)
|
||||
|
||||
// Deprecated: use variables prefixed with Err
|
||||
var (
|
||||
NoGroupSession = ErrNoGroupSession
|
||||
NoGroupSession = errors.New("no group session created")
|
||||
)
|
||||
|
||||
func getRawJSON[T any](content json.RawMessage, path ...string) *T {
|
||||
|
|
@ -46,7 +41,7 @@ func getRawJSON[T any](content json.RawMessage, path ...string) *T {
|
|||
return &result
|
||||
}
|
||||
|
||||
func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo {
|
||||
func getRelatesTo(content any) *event.RelatesTo {
|
||||
contentJSON, ok := content.(json.RawMessage)
|
||||
if ok {
|
||||
return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to")
|
||||
|
|
@ -59,7 +54,7 @@ func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo {
|
|||
if ok {
|
||||
return relatable.OptionalGetRelatesTo()
|
||||
}
|
||||
return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to")
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMentions(content any) *event.Mentions {
|
||||
|
|
@ -87,20 +82,15 @@ type rawMegolmEvent struct {
|
|||
|
||||
// IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession
|
||||
func IsShareError(err error) bool {
|
||||
return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession
|
||||
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
|
||||
}
|
||||
|
||||
func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return 0, fmt.Errorf("empty ciphertext")
|
||||
}
|
||||
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
|
||||
var err error
|
||||
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if len(decoded) < 2+binary.MaxVarintLen64 {
|
||||
return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded))
|
||||
} else if decoded[0] != 3 || decoded[1] != 8 {
|
||||
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
|
||||
}
|
||||
|
|
@ -130,7 +120,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
|
||||
} else if session == nil {
|
||||
return nil, ErrNoGroupSession
|
||||
return nil, NoGroupSession
|
||||
}
|
||||
plaintext, err := json.Marshal(&rawMegolmEvent{
|
||||
RoomID: roomID,
|
||||
|
|
@ -168,21 +158,12 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room
|
|||
Algorithm: id.AlgorithmMegolmV1,
|
||||
SessionID: session.ID(),
|
||||
MegolmCiphertext: ciphertext,
|
||||
RelatesTo: getRelatesTo(content, plaintext),
|
||||
RelatesTo: getRelatesTo(content),
|
||||
|
||||
// These are deprecated
|
||||
SenderKey: mach.account.IdentityKey(),
|
||||
DeviceID: mach.Client.DeviceID,
|
||||
}
|
||||
if mach.MSC4392Relations && encrypted.RelatesTo != nil {
|
||||
// When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content.
|
||||
// Other relations like threads are still left unencrypted.
|
||||
encrypted.RelatesTo.InReplyTo = nil
|
||||
encrypted.RelatesTo.IsFallingBack = false
|
||||
if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" {
|
||||
encrypted.RelatesTo = nil
|
||||
}
|
||||
}
|
||||
if mach.PlaintextMentions {
|
||||
encrypted.Mentions = getMentions(content)
|
||||
}
|
||||
|
|
@ -252,7 +233,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
|||
var fetchKeysForUsers []id.UserID
|
||||
|
||||
for _, userID := range users {
|
||||
log := log.With().Stringer("target_user_id", userID).Logger()
|
||||
log := log.With().Str("target_user_id", userID.String()).Logger()
|
||||
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get devices of user")
|
||||
|
|
@ -324,7 +305,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
|||
toDeviceWithheld.Messages[userID] = withheld
|
||||
}
|
||||
|
||||
log := log.With().Stringer("target_user_id", userID).Logger()
|
||||
log := log.With().Str("target_user_id", userID.String()).Logger()
|
||||
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
|
||||
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
|
||||
log.Debug().
|
||||
|
|
@ -370,19 +351,26 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
|
|||
log.Trace().Msg("Encrypting group session for all found devices")
|
||||
deviceCount := 0
|
||||
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
||||
logUsers := zerolog.Dict()
|
||||
for userID, sessions := range olmSessions {
|
||||
if len(sessions) == 0 {
|
||||
continue
|
||||
}
|
||||
logDevices := zerolog.Dict()
|
||||
output := make(map[id.DeviceID]*event.Content)
|
||||
toDevice.Messages[userID] = output
|
||||
for deviceID, device := range sessions {
|
||||
log.Trace().
|
||||
Stringer("target_user_id", userID).
|
||||
Stringer("target_device_id", deviceID).
|
||||
Stringer("target_identity_key", device.identity.IdentityKey).
|
||||
Msg("Encrypting group session for device")
|
||||
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
|
||||
output[deviceID] = &event.Content{Parsed: content}
|
||||
logDevices.Str(string(deviceID), string(device.identity.IdentityKey))
|
||||
deviceCount++
|
||||
log.Debug().
|
||||
Stringer("target_user_id", userID).
|
||||
Stringer("target_device_id", deviceID).
|
||||
Stringer("target_identity_key", device.identity.IdentityKey).
|
||||
Msg("Encrypted group session for device")
|
||||
if !mach.DisableSharedGroupSessionTracking {
|
||||
err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id)
|
||||
if err != nil {
|
||||
|
|
@ -396,13 +384,11 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
|
|||
}
|
||||
}
|
||||
}
|
||||
logUsers.Dict(string(userID), logDevices)
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Int("device_count", deviceCount).
|
||||
Int("user_count", len(toDevice.Messages)).
|
||||
Dict("destination_map", logUsers).
|
||||
Msg("Sending to-device messages to share group session")
|
||||
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -96,19 +96,15 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
|
|||
panic(err)
|
||||
}
|
||||
log := mach.machOrContextLog(ctx)
|
||||
log.Debug().
|
||||
Str("recipient_identity_key", recipient.IdentityKey.String()).
|
||||
Str("olm_session_id", session.ID().String()).
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Encrypting olm message")
|
||||
msgType, ciphertext, err := session.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ciphertextStr := string(ciphertext)
|
||||
ciphertextHash, _ := olmMessageHash(ciphertextStr)
|
||||
log.Debug().
|
||||
Stringer("event_type", evtType).
|
||||
Str("recipient_identity_key", recipient.IdentityKey.String()).
|
||||
Str("olm_session_id", session.ID().String()).
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Msg("Encrypted olm message")
|
||||
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
|
||||
|
|
@ -119,7 +115,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
|
|||
OlmCiphertext: event.OlmCiphertexts{
|
||||
recipient.IdentityKey: {
|
||||
Type: msgType,
|
||||
Body: ciphertextStr,
|
||||
Body: string(ciphertext),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
|
||||
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion)
|
||||
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
|
||||
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
|
||||
return err
|
||||
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) {
|
|||
account, err := account.NewAccount()
|
||||
assert.NoError(t, err)
|
||||
err = account.Unpickle(pickled, pickleKey)
|
||||
assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion)
|
||||
assert.ErrorIs(t, err, olm.ErrBadVersion)
|
||||
}
|
||||
|
||||
func TestLoopback(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import (
|
|||
"maunium.net/go/mautrix/crypto/olm"
|
||||
)
|
||||
|
||||
func Register() {
|
||||
func init() {
|
||||
olm.InitNewAccount = func() (olm.Account, error) {
|
||||
return NewAccount()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
|
|||
|
||||
// SharedSecret returns the shared secret between the key pair and the given public key.
|
||||
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
|
||||
// Note: the standard library checks that the output is non-zero
|
||||
return c.PrivateKey.SharedSecret(pubKey)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,6 @@ func TestCurve25519(t *testing.T) {
|
|||
fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, fromPrivate, firstKeypair)
|
||||
_, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCurve25519Case1(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ import (
|
|||
"encoding/base64"
|
||||
)
|
||||
|
||||
// These methods should only be used for raw byte operations, never with string conversion
|
||||
|
||||
// Deprecated: base64.RawStdEncoding should be used directly
|
||||
func Decode(input []byte) ([]byte, error) {
|
||||
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
|
||||
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
|
||||
|
|
@ -15,6 +14,7 @@ func Decode(input []byte) ([]byte, error) {
|
|||
return decoded[:writtenBytes], nil
|
||||
}
|
||||
|
||||
// Deprecated: base64.RawStdEncoding should be used directly
|
||||
func Encode(input []byte) []byte {
|
||||
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
|
||||
base64.RawStdEncoding.Encode(encoded, input)
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
|
|||
}
|
||||
}
|
||||
if decrypted[0] != pickleVersion {
|
||||
return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion)
|
||||
return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion)
|
||||
}
|
||||
err = json.Unmarshal(decrypted[1:], object)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -3,9 +3,6 @@ package message
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
)
|
||||
|
||||
type Decoder struct {
|
||||
|
|
@ -23,8 +20,6 @@ func (d *Decoder) ReadVarInt() (uint64, error) {
|
|||
func (d *Decoder) ReadVarBytes() ([]byte, error) {
|
||||
if n, err := d.ReadVarInt(); err != nil {
|
||||
return nil, err
|
||||
} else if n > uint64(d.Len()) {
|
||||
return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available())
|
||||
} else {
|
||||
out := make([]byte, n)
|
||||
_, err = d.Read(out)
|
||||
|
|
|
|||
|
|
@ -2,12 +2,10 @@ package message
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/aessha2"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -38,9 +36,6 @@ func (r *GroupMessage) Decode(input []byte) (err error) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
if r.Version != protocolVersion {
|
||||
return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion)
|
||||
}
|
||||
|
||||
for {
|
||||
// Read Key
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue